From e14a817963eb702b4e51237a5c2228721ea75f1d Mon Sep 17 00:00:00 2001 From: zhangyue Date: Fri, 13 Mar 2026 01:42:34 +0000 Subject: [PATCH 1/6] fix: remove uint16 test from test_add.py - Removed `torch.uint16` from the list of integer data types in the `_INT_DTYPES` tuple to streamline the code and eliminate redundancy. --- tests/test_add.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_add.py b/tests/test_add.py index 1c98d91..ef807d2 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -6,7 +6,6 @@ _INT_DTYPES = ( torch.int16, - torch.uint16, torch.int32, torch.uint32, torch.int64, @@ -39,7 +38,6 @@ (torch.float16, 1e-3, 1e-3), (torch.bfloat16, 1e-2, 5e-3), (torch.int16, 0, 0), - (torch.uint16, 0, 0), (torch.int32, 0, 0), (torch.uint32, 0, 0), (torch.int64, 0, 0), From 32109da0f52cedef358a6c285f4fef146f00de7e Mon Sep 17 00:00:00 2001 From: zhangyue Date: Fri, 13 Mar 2026 01:53:20 +0000 Subject: [PATCH 2/6] refactor: enhance dtype handling in test_add.py --- tests/test_add.py | 55 +++++++++++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/tests/test_add.py b/tests/test_add.py index ef807d2..cdcba7e 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -4,15 +4,33 @@ from tests.utils import Payload, empty_strided, randint_strided, randn_strided -_INT_DTYPES = ( - torch.int16, - torch.int32, - torch.uint32, - torch.int64, - torch.uint64, +_INT_DTYPES = tuple( + d + for d in ( + torch.int16, + torch.int32, + getattr(torch, "uint32", None), + torch.int64, + getattr(torch, "uint64", None), + ) + if d is not None ) +def _dtype_parametrize(): + candidates = [ + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + (torch.int16, 0, 0), + (torch.int32, 0, 0), + (getattr(torch, "uint32", None), 0, 0), + (torch.int64, 0, 0), + (getattr(torch, "uint64", None), 0, 0), + ] + return tuple((d, r, a) for (d, r, a) in candidates if d is not None) + + @pytest.mark.auto_act_and_assert @pytest.mark.parametrize( "shape, input_strides, other_strides, out_strides", @@ -31,19 +49,7 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) -@pytest.mark.parametrize( - ("dtype", "rtol", "atol"), - ( - (torch.float32, 1e-7, 1e-7), - (torch.float16, 1e-3, 1e-3), - (torch.bfloat16, 1e-2, 5e-3), - (torch.int16, 0, 0), - (torch.int32, 0, 0), - (torch.uint32, 0, 0), - (torch.int64, 0, 0), - (torch.uint64, 0, 0), - ), -) +@pytest.mark.parametrize(("dtype", "rtol", "atol"), _dtype_parametrize()) def test_add( shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol ): @@ -69,11 +75,18 @@ def _add(input, other, out): return out +_UINT_DTYPES = tuple( + d + for name in ("uint16", "uint32", "uint64") + if (d := getattr(torch, name, None)) is not None +) + + def _torch_add(input, other, out): - if input.dtype in (torch.uint16, torch.uint32, torch.uint64): + if input.dtype in _UINT_DTYPES: input = input.to(torch.int64) - if other.dtype in (torch.uint16, torch.uint32, torch.uint64): + if other.dtype in _UINT_DTYPES: other = other.to(torch.int64) res = torch.add(input, other) From cbe3e77b8e3a6e2c8a30bcadc8be577b9c10500c Mon Sep 17 00:00:00 2001 From: zhangyue Date: Fri, 13 Mar 2026 02:06:25 +0000 Subject: [PATCH 3/6] refactor: streamline dtype parameterization in test_add.py and enhance rms_norm fallback handling in test_rms_norm.py --- tests/test_rms_norm.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index f447091..b0c9c5d 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -59,4 +59,13 @@ def _rms_norm(input, weight, *, eps=1e-6, out=None): def _torch_rms_norm(input, weight, *, eps=1e-6, out=None): - return torch.nn.functional.rms_norm(input, input.shape[-1:], weight=weight, eps=eps) + rms_norm_fn = getattr(torch.nn.functional, "rms_norm", None) + if rms_norm_fn is not None: + return rms_norm_fn(input, input.shape[-1:], weight=weight, eps=eps) + # Fallback for PyTorch < 2.3: RMS norm = (x / sqrt(mean(x^2) + eps)) * weight + rms = torch.sqrt(torch.mean(input * input, dim=-1, keepdim=True) + eps) + result = (input / rms) * weight + if out is not None: + out.copy_(result) + return out + return result From 469679003bfec6dd1bf8f5d8c6e4117b803ad775 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Fri, 13 Mar 2026 02:39:02 +0000 Subject: [PATCH 4/6] refactor: add unsigned integer data types to test_add.py for enhanced dtype handling --- tests/test_add.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/test_add.py b/tests/test_add.py index cdcba7e..8e6863e 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -9,8 +9,16 @@ for d in ( torch.int16, torch.int32, - getattr(torch, "uint32", None), torch.int64, + ) + if d is not None +) + +_UINT_DTYPES = tuple( + d + for d in ( + getattr(torch, "uint16", None), + getattr(torch, "uint32", None), getattr(torch, "uint64", None), ) if d is not None @@ -53,7 +61,7 @@ def _dtype_parametrize(): def test_add( shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol ): - if dtype in _INT_DTYPES: + if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: input = randint_strided( 0, 100, shape, input_strides, dtype=dtype, device=device ) @@ -75,13 +83,6 @@ def _add(input, other, out): return out -_UINT_DTYPES = tuple( - d - for name in ("uint16", "uint32", "uint64") - if (d := getattr(torch, name, None)) is not None -) - - def _torch_add(input, other, out): if input.dtype in _UINT_DTYPES: input = input.to(torch.int64) From 6ba0eab30e3a26ca4779f2a34df2b9b72226f276 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 19 Mar 2026 05:44:46 +0000 Subject: [PATCH 5/6] refactor: simplify integer dtype filtering --- tests/test_add.py | 42 +++++++++++------------------------------- 1 file changed, 11 insertions(+), 31 deletions(-) diff --git a/tests/test_add.py b/tests/test_add.py index 8e6863e..d1ea0f8 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -4,41 +4,13 @@ from tests.utils import Payload, empty_strided, randint_strided, randn_strided -_INT_DTYPES = tuple( - d - for d in ( - torch.int16, - torch.int32, - torch.int64, - ) - if d is not None -) +_INT_DTYPES = (torch.int16, torch.int32, torch.int64) _UINT_DTYPES = tuple( - d - for d in ( - getattr(torch, "uint16", None), - getattr(torch, "uint32", None), - getattr(torch, "uint64", None), - ) - if d is not None + filter(None, (getattr(torch, f"uint{bits}", None) for bits in (16, 32, 64))) ) -def _dtype_parametrize(): - candidates = [ - (torch.float32, 1e-7, 1e-7), - (torch.float16, 1e-3, 1e-3), - (torch.bfloat16, 1e-2, 5e-3), - (torch.int16, 0, 0), - (torch.int32, 0, 0), - (getattr(torch, "uint32", None), 0, 0), - (torch.int64, 0, 0), - (getattr(torch, "uint64", None), 0, 0), - ] - return tuple((d, r, a) for (d, r, a) in candidates if d is not None) - - @pytest.mark.auto_act_and_assert @pytest.mark.parametrize( "shape, input_strides, other_strides, out_strides", @@ -57,7 +29,15 @@ def _dtype_parametrize(): ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) -@pytest.mark.parametrize(("dtype", "rtol", "atol"), _dtype_parametrize()) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ) + + tuple((dtype, 0, 0) for dtype in _INT_DTYPES + _UINT_DTYPES), +) def test_add( shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol ): From 866d94ca17b2bfd8aca788dc3f51530897af9ceb Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 19 Mar 2026 05:54:36 +0000 Subject: [PATCH 6/6] refactor: simplify `_torch_rms_norm` fallback logic --- tests/test_rms_norm.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index b0c9c5d..d6d4dff 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -59,13 +59,19 @@ def _rms_norm(input, weight, *, eps=1e-6, out=None): def _torch_rms_norm(input, weight, *, eps=1e-6, out=None): - rms_norm_fn = getattr(torch.nn.functional, "rms_norm", None) - if rms_norm_fn is not None: - return rms_norm_fn(input, input.shape[-1:], weight=weight, eps=eps) - # Fallback for PyTorch < 2.3: RMS norm = (x / sqrt(mean(x^2) + eps)) * weight - rms = torch.sqrt(torch.mean(input * input, dim=-1, keepdim=True) + eps) - result = (input / rms) * weight + # Fallback for `torch<2.3`: `rms_norm = (x / sqrt(mean(x^2) + eps)) * weight`. + def _fallback(input, _normalized_shape, weight, *, eps=1e-6): + rms = torch.sqrt(torch.mean(input * input, dim=-1, keepdim=True) + eps) + + return (input / rms) * weight + + rms_norm_fn = getattr(torch.nn.functional, "rms_norm", _fallback) + + result = rms_norm_fn(input, input.shape[-1:], weight=weight, eps=eps) + if out is not None: out.copy_(result) - return out - return result + else: + out = result + + return out