# Reference implementation for MXFP4 matmul (problems/amd_202602/mxfp4-mm).
# Pure-PyTorch ports of the aiter quant/shuffle helpers so the reference has
# no aiter dependency (avoids aiter compile time on each CI job).
import torch

from task import input_t, output_t
from utils import make_match_reference

# K must be divisible by 64 (scale group 32 and fp4 pack 2)
SCALE_GROUP_SIZE = 32


def shuffle_weights(x: torch.Tensor, layout=(16, 16)) -> torch.Tensor:
    """
    Pure PyTorch memory swizzle for packed-fp4 weights.

    Reorders a [..., N, K]-byte tensor into the (16, 16) tile-coalesced
    layout expected by the a4w4 GEMM.
    Ported from https://github.com/ROCm/aiter/blob/main/aiter/ops/shuffle.py

    Args:
        x: packed weight tensor; must be uint8 (the hardcoded K = 16 below
           assumes 1-byte elements).
        layout: (IN, IK) tile shape, default (16, 16).

    Returns:
        Tensor with the same shape as ``x`` but a swizzled memory layout.
    """
    # Guard the 1-byte-element assumption baked into K = 16 below.
    assert x.dtype == torch.uint8, f"expected uint8 packed weights, got {x.dtype}"

    IN, IK = layout
    BK = IK * 2

    # uint8 element_size is 1 byte, so K = 16
    K = 16
    BN = IN

    assert x.shape[-2] % BN == 0, f"{x.shape[-2]} % {BN} == {x.shape[-2] % BN}"
    assert x.shape[-1] % BK == 0, f"{x.shape[-1]} % {BK} == {x.shape[-1] % BK}"

    # ported from https://github.com/ROCm/aiter/blob/main/aiter/ops/shuffle.py
    x_ = x.view(-1, x.shape[-2] // BN, BN, x.shape[-1] // BK, BK // K, K)
    x_ = x_.permute(0, 1, 3, 4, 2, 5)
    x_ = x_.contiguous()
    x_ = x_.view(*x.shape)

    return x_


# does shuffle of scales similar to the end of the triton kernel from
# https://github.com/ROCm/aiter/blob/main/aiter/utility/fp4_utils.py to be
# consistent with aiter data creation for input generation
def shuffle_scales(bs_e8m0: torch.Tensor, M_actual: int, N_actual: int) -> torch.Tensor:
    """
    Swizzle E8M0 scales [M_actual, ceil(N_actual/32)] into the aiter asm layout.

    Pads rows to a multiple of 32 (allocation rounded to 256) and scale
    columns to a multiple of 8, filling padding with 127 (E8M0 bias, i.e.
    scale 1.0), then scatters each (row, col) entry to the interleaved offset
    the GEMM kernel expects.
    """
    M_alloc = ((M_actual + 255) // 256) * 256
    scaleM_pad = ((M_actual + 31) // 32) * 32
    scaleN_valid = (N_actual + 31) // 32
    scaleN = ((scaleN_valid + 7) // 8) * 8

    # Padding value 127 is the E8M0 bias -> multiplicative scale of 1.0.
    bs_e8m0_padded = torch.full((scaleM_pad, scaleN), 127, dtype=torch.uint8, device=bs_e8m0.device)
    bs_e8m0_padded[:M_actual, :scaleN_valid] = bs_e8m0

    m = torch.arange(scaleM_pad, device=bs_e8m0.device)[:, None]
    n = torch.arange(scaleN, device=bs_e8m0.device)[None, :]

    # Decompose (m, n) into the tile coordinates used by the asm layout.
    bs_offs_0 = m // 32
    bs_offs_1 = (m % 32) // 16
    bs_offs_2 = (m % 32) % 16

    bs_offs_3 = n // 8
    bs_offs_4 = (n % 8) // 4
    bs_offs_5 = (n % 8) % 4

    bs_offs = (
        bs_offs_1
        + bs_offs_4 * 2
        + bs_offs_2 * 4
        + bs_offs_5 * 64
        + bs_offs_3 * 256
        + bs_offs_0 * 32 * scaleN
    )

    shuffled_flat = torch.full((M_alloc * scaleN,), 127, dtype=torch.uint8, device=bs_e8m0.device)
    shuffled_flat[bs_offs.flatten()] = bs_e8m0_padded.flatten()

    return shuffled_flat.view(M_alloc, scaleN)
""" assert k % 64 == 0, "k must be divisible by 64 (scale group 32 and fp4 pack 2)" gen = torch.Generator(device="cuda") gen.manual_seed(seed) + A = torch.randn((m, k), dtype=torch.bfloat16, device="cuda", generator=gen) B = torch.randn((n, k), dtype=torch.bfloat16, device="cuda", generator=gen) - # quantized mxfp4 B - quant_func = aiter.get_triton_quant(QuantType.per_1x32) - B_q, B_scale_sh = quant_func(B, shuffle=True) + B_q, B_scale = quantize_mxfp4_pure_torch(B) + B_shuffle = shuffle_weights(B_q, layout=(16, 16)) + B_scale_sh = shuffle_scales(B_scale, M_actual=n, N_actual=k) - # shuffle B(weight) to (16,16) tile coalesced - B_shuffle = shuffle_weight(B_q, layout=(16, 16)) return (A, B, B_q, B_shuffle, B_scale_sh) -def run_torch_fp4_mm( - x: torch.Tensor, - w: torch.Tensor, - x_scales: torch.Tensor, - w_scales: torch.Tensor, - dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """ - PyTorch reference: dequant MXFP4 + E8M0 scale -> f32 -> mm -> dtype. - Same logic as aiter op_tests/test_gemm_a4w4.run_torch. 
# type helpers from https://github.com/ROCm/aiter/blob/main/aiter/utility/fp4_utils.py in pytorch

def e8m0_to_f32(scale_e8m0_biased: torch.Tensor) -> torch.Tensor:
    """
    Decode biased E8M0 scales (uint8) to float32: 127 -> 1.0, 128 -> 2.0, ...

    Special cases follow the aiter reference: encoding 0 maps to the float32
    bit pattern 0x00400000 (2**-127, a subnormal), 0xFF maps to a NaN payload.
    """
    scale_e8m0_biased = scale_e8m0_biased.contiguous().view(torch.uint8)
    zero_case = scale_e8m0_biased == 0
    nan_case = scale_e8m0_biased == 0xFF

    # The biased exponent goes straight into the float32 exponent field.
    scale_f32 = scale_e8m0_biased.to(torch.int32) << 23
    scale_f32[zero_case] = 0x00400000  # 2**-127
    scale_f32[nan_case] = 0x7F800001   # NaN
    return scale_f32.view(torch.float32)


def mxfp4_to_f32(x: torch.Tensor) -> torch.Tensor:
    """
    Unpack fp4x2 bytes [..., K/2] to float32 values [..., K].

    The low nibble of each byte holds the even-indexed element, the high
    nibble the odd-indexed one; nibbles are decoded via the 16-entry E2M1
    value table.
    """
    x = x.contiguous().view(torch.uint8)
    # Duplicate each byte, then keep the low nibble in even slots and the
    # high nibble in odd slots.
    x = x.repeat_interleave(2, dim=-1)
    x[..., ::2] = x[..., ::2] & 0xF
    x[..., 1::2] = x[..., 1::2] >> 4

    mxfp4_list = [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ]
    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device=x.device)
    return mxfp4_in_f32[x.long()]


# ported from _dynamic_mxfp4_quant_kernel_asm_layout in
# https://github.com/ROCm/aiter/blob/main/aiter/utility/fp4_utils.py
def quantize_mxfp4(x: torch.Tensor):
    """
    Quantize x [M, N] (N % 32 == 0) to MXFP4 with per-1x32 E8M0 block scales.

    Returns:
        (x_fp4 [M, N//2] uint8 with two E2M1 codes per byte,
         bs_e8m0 [M, N//32] uint8 biased block scales)
    """
    M, N = x.shape
    x_blocks = x.view(M, N // 32, 32).float()

    # Round each block's absmax: add half of the mantissa field, then clear
    # the mantissa, leaving (approximately) the nearest power of two.
    amax = x_blocks.abs().max(dim=-1, keepdim=True).values
    amax_i32 = amax.view(torch.int32)
    amax_i32 = (amax_i32 + 2097152) & ~8388607
    amax_rounded = amax_i32.view(torch.float32)

    # Block scale exponent; -2 because E2M1's max magnitude is 6.0 (~2**2).
    log_amax = torch.log2(amax_rounded)
    scale_e8m0_unbiased = torch.floor(log_amax) - 2
    # amax == 0 gives log2 = -inf; map to the minimum exponent.
    scale_e8m0_unbiased = torch.nan_to_num(scale_e8m0_unbiased, neginf=-127.0)
    scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -127, 127)

    bs_e8m0 = (scale_e8m0_unbiased + 127).to(torch.uint8)
    quant_scale = torch.exp2(-scale_e8m0_unbiased)

    # Scale values into E2M1 range, then re-encode field by field.
    qx = x_blocks * quant_scale
    qx_i32 = qx.view(torch.int32)

    e = (qx_i32 >> 23) & 0xFF
    m = qx_i32 & 0x7FFFFF
    s_bit = (qx < 0).to(torch.uint8) << 3

    E8_BIAS, E2_BIAS = 127, 1
    adjusted_exponents = E8_BIAS - (e + 1)
    shift_amount = torch.clamp(adjusted_exponents, min=0, max=31)

    # Values subnormal in E2M1: shift the mantissa (with implicit leading 1
    # at bit 22 after the >> 1) right by the exponent deficit.
    denorm_m = (4194304 | (m >> 1)) >> shift_amount
    m = torch.where(e < E8_BIAS, denorm_m, m)

    # clamp_min/clamp_max instead of torch.maximum/minimum with a fresh
    # torch.tensor(...): avoids allocating a CPU scalar tensor on every call
    # and any CPU/CUDA device mixing. Same arithmetic.
    e = e.clamp_min(E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

    # Keep 2 candidate mantissa bits, +1 then >>1 to round, saturate at 0x7.
    e2m1_tmp = ((((e << 2) | (m >> 21)) + 1) >> 1).clamp_max(0x7)
    e2m1_value = s_bit | e2m1_tmp.to(torch.uint8)
    e2m1_value = e2m1_value.view(M, N)

    # Pack pairs: even index -> low nibble, odd index -> high nibble.
    evens = e2m1_value[:, 0::2]
    odds = e2m1_value[:, 1::2]
    x_fp4 = evens | (odds << 4)

    return x_fp4, bs_e8m0.view(M, N // 32)
""" - A, B, B_q, B_shuffle, B_scale_sh = data - A = A.contiguous() - B = B.contiguous() - m, k = A.shape - n, _ = B.shape - - # 1) PyTorch impl just for your reference: dequant fp4 + e8m0 -> f32 -> mm -> bf16 - # Per-1x32 MXFP4 quant - # quant_func = aiter.get_triton_quant(QuantType.per_1x32) - # quant_func(x, shuffle=False) -> (dtypes.fp4x2, scale); scale layout matches gemm_a4w4 - # A_q, A_scale = quant_func(A, shuffle=False) - # B_q, B_scale = quant_func(B, shuffle=False) - - # gemm_a4w4 expects A [M,K/2], B [N,K/2] as dtypes.fp4x2; A_scale/B_scale [*,K/32] E8M0 - # quant_func returns scale as dtypes.fp8_e8m0; gemm_a4w4 accepts E8M0, no view to uint8 needed - # slice to exact shapes [m,k_scale] / [n,k_scale] (quant may return padded scale) - - # k_scale = k // SCALE_GROUP_SIZE - # A_scale = A_scale[:m, :k_scale].contiguous() - # B_scale = B_scale[:n, :k_scale].contiguous() - # out_torch = run_torch_fp4_mm(A_q, B_q, A_scale, B_scale, torch.bfloat16) - - # 2) aiter.gemm_a4w4 path: needs shuffled B_q and shuffled scales (see test_gemm_a4w4.py:102-105) - # Per-1x32 MXFP4 quant - quant_func = aiter.get_triton_quant(QuantType.per_1x32) - A_q, A_scale_sh = quant_func(A, shuffle=True) - # to be noted, aiter also has other a4w4 implements using triton, https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py - out_gemm = aiter.gemm_a4w4( - A_q, - B_shuffle, - A_scale_sh, - B_scale_sh, - dtype=dtypes.bf16, - bpreshuffle=True, - ) - return out_gemm + A, B, *_ = data + + A_q, A_scale = quantize_mxfp4(A) + B_q, B_scale = quantize_mxfp4(B) + + M, _ = A.shape + N, _ = B.shape + + # Dequantize back to f32 for matmul similar to old reference + x_f32 = mxfp4_to_f32(A_q) + x_scales = A_scale.repeat_interleave(SCALE_GROUP_SIZE, dim=1) + x_f32 = x_f32 * e8m0_to_f32(x_scales) + + w_f32 = mxfp4_to_f32(B_q) + w_scales = B_scale.repeat_interleave(SCALE_GROUP_SIZE, dim=1) + w_f32 = w_f32 * e8m0_to_f32(w_scales) + + return torch.mm(x_f32, 
w_f32.T).to(A.dtype)[:M, :N] -check_implementation = make_match_reference(ref_kernel, rtol=1e-02, atol=1e-02) \ No newline at end of file +check_implementation = make_match_reference(ref_kernel, rtol=1e-02, atol=1e-02) From 839e15d5794a404bd202e8c5ca915de2fee34c7b Mon Sep 17 00:00:00 2001 From: Arseni Ivanov Date: Sun, 8 Mar 2026 09:11:16 +0100 Subject: [PATCH 2/2] Fixed function signature --- problems/amd_202602/mxfp4-mm/reference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/problems/amd_202602/mxfp4-mm/reference.py b/problems/amd_202602/mxfp4-mm/reference.py index 0f08cbb9..190920a5 100644 --- a/problems/amd_202602/mxfp4-mm/reference.py +++ b/problems/amd_202602/mxfp4-mm/reference.py @@ -79,7 +79,7 @@ def generate_input(m: int, n: int, k: int, seed: int): A = torch.randn((m, k), dtype=torch.bfloat16, device="cuda", generator=gen) B = torch.randn((n, k), dtype=torch.bfloat16, device="cuda", generator=gen) - B_q, B_scale = quantize_mxfp4_pure_torch(B) + B_q, B_scale = quantize_mxfp4(B) B_shuffle = shuffle_weights(B_q, layout=(16, 16)) B_scale_sh = shuffle_scales(B_scale, M_actual=n, N_actual=k)