diff --git a/tests/test_add.py b/tests/test_add.py index 1c98d91..d1ea0f8 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -4,13 +4,10 @@ from tests.utils import Payload, empty_strided, randint_strided, randn_strided -_INT_DTYPES = ( - torch.int16, - torch.uint16, - torch.int32, - torch.uint32, - torch.int64, - torch.uint64, +_INT_DTYPES = (torch.int16, torch.int32, torch.int64) + +_UINT_DTYPES = tuple( + filter(None, (getattr(torch, f"uint{bits}", None) for bits in (16, 32, 64))) ) @@ -38,18 +35,13 @@ (torch.float32, 1e-7, 1e-7), (torch.float16, 1e-3, 1e-3), (torch.bfloat16, 1e-2, 5e-3), - (torch.int16, 0, 0), - (torch.uint16, 0, 0), - (torch.int32, 0, 0), - (torch.uint32, 0, 0), - (torch.int64, 0, 0), - (torch.uint64, 0, 0), - ), + ) + + tuple((dtype, 0, 0) for dtype in _INT_DTYPES + _UINT_DTYPES), ) def test_add( shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol ): - if dtype in _INT_DTYPES: + if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: input = randint_strided( 0, 100, shape, input_strides, dtype=dtype, device=device ) @@ -72,10 +64,10 @@ def _add(input, other, out): def _torch_add(input, other, out): - if input.dtype in (torch.uint16, torch.uint32, torch.uint64): + if input.dtype in _UINT_DTYPES: input = input.to(torch.int64) - if other.dtype in (torch.uint16, torch.uint32, torch.uint64): + if other.dtype in _UINT_DTYPES: other = other.to(torch.int64) res = torch.add(input, other) diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index f447091..d6d4dff 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -59,4 +59,19 @@ def _rms_norm(input, weight, *, eps=1e-6, out=None): def _torch_rms_norm(input, weight, *, eps=1e-6, out=None): - return torch.nn.functional.rms_norm(input, input.shape[-1:], weight=weight, eps=eps) + # Fallback for `torch<2.4`: `rms_norm = (x / sqrt(mean(x^2) + eps)) * weight`. 
+ def _fallback(input, _normalized_shape, weight, *, eps=1e-6): + rms = torch.sqrt(torch.mean(input * input, dim=-1, keepdim=True) + eps) + + return (input / rms) * weight + + rms_norm_fn = getattr(torch.nn.functional, "rms_norm", _fallback) + + result = rms_norm_fn(input, input.shape[-1:], weight=weight, eps=eps) + + if out is not None: + out.copy_(result) + else: + out = result + + return out