diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..386efa16 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "base"] + path = base + url = https://github.com/pytorch/helion diff --git a/base b/base new file mode 160000 index 00000000..5eb88947 --- /dev/null +++ b/base @@ -0,0 +1 @@ +Subproject commit 5eb8894723e86f1e9f218459e7af06feb030212d diff --git a/problems/helion/causal_conv1d_py/reference.py b/problems/helion/causal_conv1d_py/reference.py index 0d2ae2f5..8276ed7a 100644 --- a/problems/helion/causal_conv1d_py/reference.py +++ b/problems/helion/causal_conv1d_py/reference.py @@ -1,3 +1,22 @@ +""" +Causal depthwise Conv1D — reference implementation. + +Used in SSM-based LLM architectures such as Mamba, where a short causal +(left-padded) depthwise convolution mixes local context along the sequence +dimension independently per channel, before the selective state-space step. + +"Causal" means output position t depends only on input positions <= t, +enforced by padding W-1 zeros on the left and no padding on the right. + +This module provides a pure-PyTorch reference against which optimized +Triton/CUDA kernels are verified. + +Shapes: + x : (B, D, S) — batch, channels (model dim), sequence length + weight : (D, W) — one filter of width W per channel (depthwise) + bias : (D,) — per-channel bias + output : (B, D, S) — same shape as input +""" import torch import torch.nn.functional as F from task import input_t, output_t @@ -5,6 +24,7 @@ def generate_input(B: int, D: int, S: int, W: int, seed: int) -> input_t: + """Generate random (x, weight, bias) on CUDA with a fixed seed for reproducibility.""" gen = torch.Generator(device="cuda") gen.manual_seed(seed) x = torch.randn(B, D, S, dtype=torch.float32, device="cuda", generator=gen).contiguous() @@ -14,6 +34,13 @@ def generate_input(B: int, D: int, S: int, W: int, seed: int) -> input_t: def ref_kernel(data: input_t) -> output_t: + """ + Causal depthwise Conv1D via PyTorch. + + Pads W-1 zeros on the left of the sequence so that the convolution at + position t sees only x[:, :, t-W+1 : t+1], preserving causality. + groups=D makes each channel use its own filter (depthwise). + """ with DeterministicContext(): x, weight, bias = data B, D, S = x.shape diff --git a/problems/helion/causal_conv1d_py/submission.py b/problems/helion/causal_conv1d_py/submission.py index 92716763..765a06e4 100644 --- a/problems/helion/causal_conv1d_py/submission.py +++ b/problems/helion/causal_conv1d_py/submission.py @@ -6,29 +6,23 @@ # Per-shape configs: map (B, D, S, W) to optimized helion.Config objects. -# Autotune locally for each shape, then paste the best config here. SHAPE_CONFIGS: dict[tuple, helion.Config] = { # Test shapes - (1, 64, 64, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (2, 128, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 256, 256, 3): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 128, 64, 8): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (4, 64, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check + (1, 64, 64, 4): helion.Config(block_sizes=[1, 64], num_warps=2, num_stages=2), + (2, 128, 128, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=2), + (1, 256, 256, 3): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=2), + (1, 128, 64, 8): helion.Config(block_sizes=[1, 64], num_warps=2, num_stages=2), + (4, 64, 128, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=2), + # Non-benchmarked shapes (keep for safety) + (1, 768, 512, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=3), + (1, 768, 2048, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=3), # Benchmark shapes - (1, 768, 512, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (1, 768, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (1, 1536, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (1, 2560, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (1, 2560, 4096, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config + (1, 1536, 2048, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=3), + (1, 2560, 2048, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=3), + (1, 2560, 4096, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=4), } -# Optional: add advanced_controls_file to your Config for extra performance (see docs). -# Autotune with autotune_search_acf to find the best ACF, then hardcode it: -# helion.Config(..., advanced_controls_file="/opt/booster_pack/causal_conv_0.acf") - - -# NOTE: This is an intentionally inefficient baseline implementation. def _make_kernel(config: helion.Config): @helion.kernel(static_shapes=True, config=config) def kernel( @@ -46,20 +40,11 @@ def kernel( for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]): bi = rb.begin - acc1 = hl.zeros([rd, rs], dtype=torch.float32) - acc2 = hl.zeros([rd, rs], dtype=torch.float32) - acc3 = hl.zeros([rd, rs], dtype=torch.float32) + acc = hl.zeros([rd, rs], dtype=torch.float32) for j in range(W): - c1 = w[rd, j].to(torch.float32) - x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) - acc1 = acc1 + x1 * c1[:, None] - c2 = w[rd, j].to(torch.float32) - x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) - acc2 = acc2 + x2 * c2[:, None] - c3 = w[rd, j].to(torch.float32) - x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) - acc3 = acc3 + x3 * c3[:, None] - acc = (acc1 + acc2 + acc3) / 3.0 + c = w[rd, j].to(torch.float32) + x_val = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) + acc = acc + x_val * c[:, None] acc = acc + b[rd].to(torch.float32)[:, None] y[rb, rd, rs] = acc[None, :, :].to(y.dtype) diff --git a/problems/helion/fp8_quant_py/prompt.md b/problems/helion/fp8_quant_py/prompt.md new file mode 100644 index 00000000..413f1dce --- /dev/null +++ b/problems/helion/fp8_quant_py/prompt.md @@ -0,0 +1,2 @@ +Goal: Maximize Helion Kernel in @submission.py; Ensure correctness using @reference.py. Details: Use @../../../base/examples/ that contains optimized helion examples. + @../../helion/, use "python eval.py test causal_conv1d_py/" to test correctness and "python eval.py benchmark causal_conv1d_py/" to measure performance. diff --git a/problems/helion/fp8_quant_py/reference.py b/problems/helion/fp8_quant_py/reference.py index bcad6943..c2f1b2ff 100644 --- a/problems/helion/fp8_quant_py/reference.py +++ b/problems/helion/fp8_quant_py/reference.py @@ -22,17 +22,21 @@ def ref_kernel(data: input_t) -> output_t: num_groups = x_s.shape[1] group_size = hidden_dim // num_groups + # convert to float32 for computation x_f32 = x.float() + # reshape into fp8 groups x_grouped = x_f32.reshape(num_tokens, num_groups, group_size) - # Per-group absmax + # Per-group absmax and clamp to mimimum fp8 value absmax = x_grouped.abs().amax(dim=-1).clamp(min=FP8_EPS) # Scale = absmax / fp8_max + # scale abs-max by maximum fp8 scale = absmax / FP8_MAX - # Quantize + # Quantize by dividing by scale and clamping to fp8 range quantized = (x_grouped / scale.unsqueeze(-1)).clamp(FP8_MIN, FP8_MAX) + # Reshape to original shape quantized = quantized.reshape(num_tokens, hidden_dim) x_q[...] = quantized diff --git a/problems/helion/fp8_quant_py/submission.py b/problems/helion/fp8_quant_py/submission.py index 4b562fa9..d2ddbb63 100644 --- a/problems/helion/fp8_quant_py/submission.py +++ b/problems/helion/fp8_quant_py/submission.py @@ -10,79 +10,71 @@ # Autotune locally for each shape, then paste the best config here. SHAPE_CONFIGS: dict[tuple, helion.Config] = { # Test shapes - (1, 256, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (4, 512, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (16, 1024, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (8, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check + (1, 256, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), + (4, 512, 128): helion.Config(block_sizes=[2], num_warps=2, num_stages=1), + (16, 1024, 64): helion.Config(block_sizes=[4], num_warps=2, num_stages=1), + (1, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), + (8, 4096, 128): helion.Config(block_sizes=[4], num_warps=2, num_stages=1), # Benchmark shapes - # (1, 4096, 128) already covered above - (16, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (256, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (256, 8192, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4096, 7168, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config + (16, 4096, 128): helion.Config(block_sizes=[4], num_warps=4, num_stages=2), + (256, 4096, 128): helion.Config(block_sizes=[8], num_warps=8, num_stages=2), + (256, 8192, 128): helion.Config(block_sizes=[8], num_warps=8, num_stages=2), + (4096, 7168, 128):helion.Config(block_sizes=[16], num_warps=8, num_stages=2), } -# Optional: add advanced_controls_file to your Config for extra performance (see docs). -# Autotune with autotune_search_acf to find the best ACF, then hardcode it: -# helion.Config(..., advanced_controls_file="/opt/booster_pack/fp8_group_quant_0.acf") - - -# NOTE: This is an intentionally inefficient baseline implementation. def _make_kernel(config: helion.Config): @helion.kernel(static_shapes=True, config=config) def kernel( - data: torch.Tensor, # [N, G] input rows - scales_out: torch.Tensor, # [N] output normalization factors + data: torch.Tensor, # [N, G] float32 + scales_out: torch.Tensor, # [N] ) -> torch.Tensor: nrows = data.size(0) ncols = hl.specialize(data.size(1)) MAX_VAL = 448.0 - qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device) - for rr in hl.tile(nrows): row = data[rr, :].to(torch.float32) - - abs1 = torch.abs(row) - amax1 = torch.amax(abs1, -1) - abs2 = torch.abs(row) - amax2 = torch.amax(abs2, -1) - abs3 = torch.abs(row) - amax3 = torch.amax(abs3, -1) - amax = (amax1 + amax2 + amax3) / 3.0 - amax = torch.clamp(amax, min=1e-10) + amax = torch.amax(torch.abs(row), dim=-1).clamp(min=1e-10) scale = amax / MAX_VAL - - q1 = row / scale[:, None] - q2 = row / scale[:, None] - q3 = row / scale[:, None] - qout[rr, :] = (q1 + q2 + q3) / 3.0 + qout[rr, :] = torch.clamp(row / scale[:, None], -MAX_VAL, MAX_VAL) scales_out[rr] = scale - return qout return kernel +# Create a kernel for each shape based on the configs above. _KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()} +# TODO Fuse reshapes into kernel. Reshape is a no-op for the GPU, but requires +# copy kernel in pytorch eager mode. + def custom_kernel(data: input_t) -> output_t: x, x_q, x_s = data + # num_tokens, hidden_dim = x.shape T, H = x.shape + # num_groups = x_s.shape[1] G = x_s.shape[1] + # group_size = hidden_dim // num_groups gsz = H // G + # merge num_tokens and num_groups into a single dimension for the kernel N = T * G + # Select the appropriate kernel based on the input shape. kernel = _KERNELS[(T, H, gsz)] + # Reshape inputs and scales for the kernel + # inputs is input argument + # scales is output argument, but we need to pass it in as an argument flat_in = x.reshape(N, gsz) flat_s = x_s.reshape(N) + # return value is quantized output flat_q = kernel(flat_in, flat_s) + # Reshape outputs back to original shapes x_q[...] = flat_q.reshape(T, H) x_s[...] = flat_s.reshape(T, G) return x_q, x_s diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py index 04e0ecfc..3b723bd2 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py +++ b/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py @@ -9,17 +9,17 @@ # Autotune locally for each shape, then paste the best config here. SHAPE_CONFIGS: dict[tuple, helion.Config] = { # Test shapes - (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check + (1, 64, 2, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (2, 128, 4, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (1, 256, 4, 64, 128): helion.Config(block_sizes=[128], num_warps=4, num_stages=2), # Benchmark shapes - (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config + (1, 64, 1, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (2, 512, 3, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (2, 1024, 3, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (3, 1024, 4, 100, 100): helion.Config(block_sizes=[100], num_warps=4, num_stages=2), + (4, 1024, 4, 128, 128): helion.Config(block_sizes=[128], num_warps=4, num_stages=2), + (2, 1536, 4, 128, 128): helion.Config(block_sizes=[128], num_warps=4, num_stages=2), + (4, 2048, 8, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), } @@ -28,7 +28,6 @@ # helion.Config(..., advanced_controls_file="/opt/booster_pack/chunk_fwd_h_0.acf") -# NOTE: This is an intentionally inefficient baseline implementation. def _make_kernel(config: helion.Config): @helion.kernel(static_shapes=True, dot_precision="ieee", config=config) def kernel( @@ -46,40 +45,33 @@ def kernel( NT = (T + C - 1) // C h_out = torch.empty(B, NT, H, K, V, dtype=k.dtype, device=k.device) v_out = torch.empty_like(u) + block_v = hl.register_block_size(V) - BH = B * H - - for flat, tv in hl.tile([BH, V], block_size=[1, 8]): - b_idx = flat.begin // H - h_idx = flat.begin % H - state = hl.zeros([K, tv], dtype=torch.float32) + for tile_b, tile_h, tile_v in hl.tile([B, H, V], block_size=[1, 1, block_v]): + b_idx = tile_b.id + h_idx = tile_h.id + state = hl.zeros([K, tile_v], dtype=torch.float32) for tc in hl.tile(T, block_size=C): - chunk_idx = tc.begin // C t_end = min(tc.begin + C, T) - 1 - h_out[b_idx, chunk_idx, h_idx, :, tv] = state.to(k.dtype) + h_out[b_idx, tc.id, h_idx, :, tile_v] = state.to(k.dtype) - proj1 = hl.dot( - w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32 - ) - proj2 = hl.dot( - w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32 - ) - proj = (proj1 + proj2) * 0.5 - diff = u[b_idx, tc, h_idx, tv].to(torch.float32) - proj - v_out[b_idx, tc, h_idx, tv] = diff.to(u.dtype) + b_w = w[b_idx, tc, h_idx, :] + proj = hl.dot(b_w, state.to(k.dtype), out_dtype=torch.float32) + p_v = u[b_idx, tc, h_idx, tile_v].to(torch.float32) + diff = p_v - proj + v_out[b_idx, tc, h_idx, tile_v] = diff.to(u.dtype) g_end = g[b_idx, t_end, h_idx].to(torch.float32) - g_t = g[b_idx, tc, h_idx].to(torch.float32) + b_g = g[b_idx, tc, h_idx].to(torch.float32) valid = tc.index < T - alpha = torch.where(valid, torch.exp(g_end - g_t), 0.0) - k_adj = k[b_idx, tc, h_idx, :] * alpha[:, None] + alpha = torch.where(valid, torch.exp(g_end - b_g), 0.0) + diff_gated = (diff * alpha[:, None]).to(k.dtype) state = state * torch.exp(g_end) - upd1 = hl.dot(k_adj.T, diff, out_dtype=torch.float32) - upd2 = hl.dot(k_adj.T, diff, out_dtype=torch.float32) - state = state + (upd1 + upd2) * 0.5 + b_k = k[b_idx, tc, h_idx, :] + state = hl.dot(b_k.T, diff_gated, acc=state) return h_out, v_out diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py index eb4de947..76cbeb0d 100644 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py +++ b/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py @@ -9,17 +9,17 @@ # Autotune locally for each shape, then paste the best config here. SHAPE_CONFIGS: dict[tuple, helion.Config] = { # Test shapes - (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check - (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check - (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check + (1, 64, 2, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (2, 128, 4, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (1, 256, 4, 64, 128): helion.Config(block_sizes=[128], num_warps=4, num_stages=2), # Benchmark shapes - (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config + (1, 64, 1, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (2, 512, 3, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (2, 1024, 3, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), + (3, 1024, 4, 100, 100): helion.Config(block_sizes=[100], num_warps=4, num_stages=2), + (4, 1024, 4, 128, 128): helion.Config(block_sizes=[128], num_warps=4, num_stages=2), + (2, 1536, 4, 128, 128): helion.Config(block_sizes=[128], num_warps=4, num_stages=2), + (4, 2048, 8, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2), } @@ -47,16 +47,16 @@ def kernel( out = torch.empty_like(v) - BH = B * H - for flat_bh, tile_t in hl.tile([BH, T], block_size=[1, C]): - b_idx = flat_bh.begin // H - h_idx = flat_bh.begin % H + block_v = hl.register_block_size(V) + for tile_b, tile_h, tile_t, tile_v in hl.tile([B, H, T, V], block_size=[1, 1, C, block_v]): + b_idx = tile_b.begin + h_idx = tile_h.begin c_idx = tile_t.begin // C g_vals = g[b_idx, tile_t, h_idx] q_tile = q[b_idx, tile_t, h_idx, :] k_tile = k[b_idx, tile_t, h_idx, :] - v_tile = v[b_idx, tile_t, h_idx, :] + v_tile = v[b_idx, tile_t, h_idx, tile_v] # intra-chunk: q @ k^T * exp(g_i - g_j), with causal mask qk = hl.dot(q_tile, k_tile.T) @@ -68,9 +68,9 @@ def kernel( # inter-chunk: (q @ h) * exp(g) q_s = q_tile * torch.exp(g_vals)[:, None] - global_out = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :]) + global_out = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, tile_v]) - out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype) + out[b_idx, tile_t, h_idx, tile_v] = ((global_out + local_out) * scale).to(out.dtype) return out diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/reference.py b/problems/helion/gated_deltanet_recompute_w_u_py/reference.py index bd7c1507..6d02a6db 100644 --- a/problems/helion/gated_deltanet_recompute_w_u_py/reference.py +++ b/problems/helion/gated_deltanet_recompute_w_u_py/reference.py @@ -1,3 +1,100 @@ +""" +Gated DeltaNet: chunk-wise recomputation of w and u +==================================================== + +Background +---------- +DeltaNet is a linear recurrent model based on the **delta rule** (Widrow-Hoff, +1960), which updates a matrix-valued hidden state H_t ∈ R^{K×V} as: + + H_t = exp(g_t) · H_{t-1} + β_t · (v_t − H_{t-1} k_t) ⊗ k_t + +where: + k_t ∈ R^K — key (query used for the associative read) + v_t ∈ R^V — value (target to associate) + β_t ∈ (0,1) — per-step learning rate / write strength + g_t ≤ 0 — log-decay (gate); exp(g_t) ∈ (0,1] damps old memories + ⊗ — outer product + +The exp(g_t) term makes this the **Gated DeltaNet** (Yang et al., 2024). +The delta rule term `−H_{t-1} k_t ⊗ k_t` is an error-correction: it first +reads what the current memory says about k_t, then erases it before writing +v_t, so the net update is proportional to the prediction error (v_t − ŷ_t). + +Chunk-wise parallel form +------------------------ +To exploit GPU parallelism, the sequence of length T is split into NT = T/C +non-overlapping chunks of size C. Within chunk n (positions t = nC … nC+C−1) +define τ = t − nC as the within-chunk index. + +The chunk boundary state update can be written as: + + H_{n+1} = exp(G_n) · H_n + Σ_τ u_τ ⊗ k_τ − w_τ ⊗ k_τ · H_n + = exp(G_n) · H_n + (U_n − W_n) · H_n (matrix form) + +where G_n = Σ_τ g_{nC+τ} is the total chunk decay. The matrices U_n and W_n +collect the within-chunk "value writes" and "key absorptions" respectively, +and their efficient computation is the purpose of this file. + +Intra-chunk interaction matrix M +--------------------------------- +For positions i > j within the same chunk (strict causal order), position j's +delta-rule write creates a residual that position i must account for: + + M_{i,j} = β_i · exp(g̃_i − g̃_j) · (k_i k_j^T) (i > j, else 0) + +where g̃_t is the **chunk-local** cumulative sum of g (restarted at each chunk +boundary). The exponential exp(g̃_i − g̃_j) is the decay from position j to +position i within the chunk. + +The k_i k_j^T factor arises because the delta rule at step i reads k_i^T H, +and H already contains the write β_j k_j ⊗ k_j from step j, contributing +β_j (k_i^T k_j) k_j to the correction term. + +Solving for A = (I + M)^{-1} +------------------------------ +Within a chunk the delta-rule corrections chain together: the write at step j +is partially cancelled by step j+1, which is further modified by step j+2, +etc. Summing this geometric-series of interactions leads to a linear system: + + (I + M) · A = I ⟹ A = (I + M)^{-1} + +Because M is **strictly lower-triangular**, (I + M) is unit lower-triangular +and the solve is numerically stable. A is computed once per chunk and reused +for both w and u. + +Computing w and u +----------------- +Given the resolved interaction matrix A = (I + M)^{-1}: + + u_i = Σ_j A_{i,j} · β_j · v_j (shape: [K]) + w_i = Σ_j A_{i,j} · β_j · exp(g̃_j) · k_j (shape: [V]) + +In matrix form (over the C positions in a chunk): + + U = A · (β ⊙ v) (u_c in the code) + W = A · (β ⊙ exp(g̃) ⊙ k) (w_c in the code) + + u accumulates the net value contributions reaching each position after all + within-chunk delta-rule cancellations have been applied. + + w accumulates the net key absorptions; the exp(g̃_j) factor converts the + local-cumsum decay at j into the correct scale for the outer chunk loop + (where H_{n,0} is decayed by the full chunk decay to position i). + +These (w, u) pairs are the inputs consumed by the chunk-wise outer state- +transition loop of the full Gated DeltaNet forward pass. + +References +---------- +- Yang et al. "Gated Delta Networks: Improving Mamba2 with Delta Rule" (2024) + https://arxiv.org/abs/2412.06464 +- Widrow & Hoff. "Adaptive Switching Circuits" (1960) — original delta rule +- Schmidhuber. "Learning to Control Fast-Weight Memories" (1992) — fast-weight + programmers; modern framing of outer-product / delta-rule recurrences +- Sun et al. "Retentive Network" / "Flash Linear Attention" — chunk-wise + formulation of linear RNN state transitions used here +""" import torch from task import input_t, output_t from utils import verbose_allclose diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/submission.py b/problems/helion/gated_deltanet_recompute_w_u_py/submission.py index 07fb0691..81f07033 100644 --- a/problems/helion/gated_deltanet_recompute_w_u_py/submission.py +++ b/problems/helion/gated_deltanet_recompute_w_u_py/submission.py @@ -9,17 +9,17 @@ # Autotune locally for each shape, then paste the best config here. SHAPE_CONFIGS: dict[tuple, helion.Config] = { # Test shapes - (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check + (1, 64, 2, 64, 64): helion.Config(block_sizes=[64, 64], num_warps=4, num_stages=3), + (2, 128, 4, 64, 64): helion.Config(block_sizes=[64, 64], num_warps=4, num_stages=3), + (1, 256, 4, 64, 128): helion.Config(block_sizes=[64, 128], num_warps=4, num_stages=3), # Benchmark shapes - (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config + (1, 64, 1, 64, 64): helion.Config(block_sizes=[64, 64], num_warps=4, num_stages=3), + (2, 512, 3, 64, 64): helion.Config(block_sizes=[64, 64], num_warps=4, num_stages=3), + (2, 1024, 3, 64, 64): helion.Config(block_sizes=[64, 64], num_warps=4, num_stages=3), + (3, 1024, 4, 100, 100): helion.Config(block_sizes=[100, 100], num_warps=4, num_stages=3), + (4, 1024, 4, 128, 128): helion.Config(block_sizes=[64, 64], num_warps=4, num_stages=3), + (2, 1536, 4, 128, 128): helion.Config(block_sizes=[64, 64], num_warps=4, num_stages=3), + (4, 2048, 8, 64, 64): helion.Config(block_sizes=[64, 64], num_warps=4, num_stages=3), } @@ -28,14 +28,13 @@ # helion.Config(..., advanced_controls_file="/opt/booster_pack/recompute_w_u_fwd_0.acf") -# NOTE: This is an intentionally inefficient baseline implementation. def _make_kernel(config: helion.Config): @helion.kernel(static_shapes=True, dot_precision="ieee", config=config) def kernel( k: torch.Tensor, # [B, T, H, K] v: torch.Tensor, # [B, T, H, V] beta: torch.Tensor, # [B, T, H] - A: torch.Tensor, # [B, T, H, BT] + A: torch.Tensor, # [B, T, H, C] g: torch.Tensor, # [B, T, H] ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K = k.shape @@ -47,42 +46,26 @@ def kernel( w_out = torch.empty_like(k) u_out = torch.empty_like(v) - BH = B * H - for flat_bh, rt in hl.tile([BH, T], block_size=[1, C]): - b_idx = flat_bh.begin // H - h_idx = flat_bh.begin % H - - w_acc1 = hl.zeros([rt, K], dtype=torch.float32) - u_acc1 = hl.zeros([rt, V], dtype=torch.float32) - w_acc2 = hl.zeros([rt, K], dtype=torch.float32) - u_acc2 = hl.zeros([rt, V], dtype=torch.float32) - - for ci in range(C): - t_ci = rt.begin + ci - a_col = A[b_idx, rt, h_idx, ci].to(torch.float32) - coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32) - decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32)) - - k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32) - v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32) - - w_acc1 = w_acc1 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :] - u_acc1 = u_acc1 + a_col[:, None] * (v_ci * coeff_ci)[None, :] - - for ci in range(C - 1, -1, -1): - t_ci = rt.begin + ci - a_col = A[b_idx, rt, h_idx, ci].to(torch.float32) - coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32) - decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32)) - - k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32) - v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32) - - w_acc2 = w_acc2 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :] - u_acc2 = u_acc2 + a_col[:, None] * (v_ci * coeff_ci)[None, :] - - w_out[b_idx, rt, h_idx, :] = ((w_acc1 + w_acc2) * 0.5).to(k.dtype) - u_out[b_idx, rt, h_idx, :] = ((u_acc1 + u_acc2) * 0.5).to(v.dtype) + block_k = hl.register_block_size(K) + block_v = hl.register_block_size(V) + + # W = A_mat @ (β * exp(g̃) * k) + for tile_b, tile_h, rt, tile_k in hl.tile([B, H, T, K], block_size=[1, 1, C, block_k]): + b = tile_b.begin + h = tile_h.begin + a_mat_w = A[b, rt, h, :].to(torch.float32) # [C, C] + coeff_w = (beta[b, rt, h] * torch.exp(g[b, rt, h])).to(torch.float32) # [C] + sk = k[b, rt, h, tile_k].to(torch.float32) * coeff_w[:, None] # [C, block_k] + w_out[b, rt, h, tile_k] = hl.dot(a_mat_w, sk, out_dtype=torch.float32).to(k.dtype) + + # U = A_mat @ (β * v) + for tile_b2, tile_h2, rt2, tile_v in hl.tile([B, H, T, V], block_size=[1, 1, C, block_v]): + b2 = tile_b2.begin + h2 = tile_h2.begin + a_mat_u = A[b2, rt2, h2, :].to(torch.float32) # [C, C] + bv = beta[b2, rt2, h2].to(torch.float32) # [C] + sv = v[b2, rt2, h2, tile_v].to(torch.float32) * bv[:, None] # [C, block_v] + u_out[b2, rt2, h2, tile_v] = hl.dot(a_mat_u, sv, out_dtype=torch.float32).to(v.dtype) return w_out, u_out diff --git a/problems/helion/requirements.txt b/problems/helion/requirements.txt new file mode 100644 index 00000000..5816a883 --- /dev/null +++ b/problems/helion/requirements.txt @@ -0,0 +1 @@ +helion