Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 9 additions & 17 deletions tests/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@

from tests.utils import Payload, empty_strided, randint_strided, randn_strided

# Signed integer dtypes exercised by the tests.
_INT_DTYPES = (torch.int16, torch.int32, torch.int64)

# Unsigned integer dtypes, looked up by name so that a PyTorch build lacking
# one of them simply leaves it out: `getattr` yields `None` for a missing
# attribute and `filter(None, ...)` drops those entries.
_UINT_DTYPES = tuple(
    filter(None, (getattr(torch, f"uint{bits}", None) for bits in (16, 32, 64)))
)


Expand Down Expand Up @@ -38,18 +35,13 @@
(torch.float32, 1e-7, 1e-7),
(torch.float16, 1e-3, 1e-3),
(torch.bfloat16, 1e-2, 5e-3),
(torch.int16, 0, 0),
(torch.uint16, 0, 0),
(torch.int32, 0, 0),
(torch.uint32, 0, 0),
(torch.int64, 0, 0),
(torch.uint64, 0, 0),
),
)
+ tuple((dtype, 0, 0) for dtype in _INT_DTYPES + _UINT_DTYPES),
)
def test_add(
shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol
):
if dtype in _INT_DTYPES:
if dtype in _INT_DTYPES or dtype in _UINT_DTYPES:
input = randint_strided(
0, 100, shape, input_strides, dtype=dtype, device=device
)
Expand All @@ -72,10 +64,10 @@ def _add(input, other, out):


def _torch_add(input, other, out):
if input.dtype in (torch.uint16, torch.uint32, torch.uint64):
if input.dtype in _UINT_DTYPES:
input = input.to(torch.int64)

if other.dtype in (torch.uint16, torch.uint32, torch.uint64):
if other.dtype in _UINT_DTYPES:
other = other.to(torch.int64)

res = torch.add(input, other)
Expand Down
17 changes: 16 additions & 1 deletion tests/test_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,19 @@ def _rms_norm(input, weight, *, eps=1e-6, out=None):


def _torch_rms_norm(input, weight, *, eps=1e-6, out=None):
    """Compute the reference RMS norm of `input` over its last dimension.

    Uses `torch.nn.functional.rms_norm` when the installed PyTorch provides
    it, and an equivalent closed-form expression otherwise.

    Args:
        input: Tensor to normalize.
        weight: Scale applied after normalization.
        eps: Value added to the mean of squares before the square root.
        out: Optional destination tensor; when given, the result is copied
            into it.

    Returns:
        `out` when it was provided, otherwise the newly computed tensor.
    """

    # Fallback for PyTorch builds that do not ship
    # `torch.nn.functional.rms_norm`:
    # `rms_norm = (x / sqrt(mean(x^2) + eps)) * weight`.
    # `_normalized_shape` is accepted only to mirror the built-in signature;
    # the reduction is always over the last dimension.
    def _fallback(input, _normalized_shape, weight, *, eps=1e-6):
        rms = torch.sqrt(torch.mean(input * input, dim=-1, keepdim=True) + eps)

        return (input / rms) * weight

    rms_norm_fn = getattr(torch.nn.functional, "rms_norm", _fallback)

    result = rms_norm_fn(input, input.shape[-1:], weight=weight, eps=eps)

    if out is None:
        return result

    out.copy_(result)

    return out