Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "base"]
path = base
url = https://github.com/pytorch/helion
1 change: 1 addition & 0 deletions base
Submodule base added at 5eb889
27 changes: 27 additions & 0 deletions problems/helion/causal_conv1d_py/reference.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
"""
Causal depthwise Conv1D — reference implementation.

Used in SSM-based LLM architectures such as Mamba, where a short causal
(left-padded) depthwise convolution mixes local context along the sequence
dimension independently per channel, before the selective state-space step.

"Causal" means output position t depends only on input positions <= t,
enforced by padding W-1 zeros on the left and no padding on the right.

This module provides a pure-PyTorch reference against which optimized
Triton/CUDA kernels are verified.

Shapes:
x : (B, D, S) — batch, channels (model dim), sequence length
weight : (D, W) — one filter of width W per channel (depthwise)
bias : (D,) — per-channel bias
output : (B, D, S) — same shape as input
"""
import torch
import torch.nn.functional as F
from task import input_t, output_t
from utils import make_match_reference, DeterministicContext


def generate_input(B: int, D: int, S: int, W: int, seed: int) -> input_t:
"""Generate random (x, weight, bias) on CUDA with a fixed seed for reproducibility."""
gen = torch.Generator(device="cuda")
gen.manual_seed(seed)
x = torch.randn(B, D, S, dtype=torch.float32, device="cuda", generator=gen).contiguous()
Expand All @@ -14,6 +34,13 @@ def generate_input(B: int, D: int, S: int, W: int, seed: int) -> input_t:


def ref_kernel(data: input_t) -> output_t:
"""
Causal depthwise Conv1D via PyTorch.

Pads W-1 zeros on the left of the sequence so that the convolution at
position t sees only x[:, :, t-W+1 : t+1], preserving causality.
groups=D makes each channel use its own filter (depthwise).
"""
with DeterministicContext():
x, weight, bias = data
B, D, S = x.shape
Expand Down
45 changes: 15 additions & 30 deletions problems/helion/causal_conv1d_py/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,23 @@


# Per-shape configs: map (B, D, S, W) to optimized helion.Config objects.
# Autotune locally for each shape, then paste the best config here.
SHAPE_CONFIGS: dict[tuple, helion.Config] = {
# Test shapes
(1, 64, 64, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(2, 128, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(1, 256, 256, 3): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(1, 128, 64, 8): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(4, 64, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(1, 64, 64, 4): helion.Config(block_sizes=[1, 64], num_warps=2, num_stages=2),
(2, 128, 128, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=2),
(1, 256, 256, 3): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=2),
(1, 128, 64, 8): helion.Config(block_sizes=[1, 64], num_warps=2, num_stages=2),
(4, 64, 128, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=2),
# Non-benchmarked shapes (keep for safety)
(1, 768, 512, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=3),
(1, 768, 2048, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=3),
# Benchmark shapes
(1, 768, 512, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(1, 768, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(1, 1536, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(1, 2560, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(1, 2560, 4096, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(1, 1536, 2048, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=3),
(1, 2560, 2048, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=3),
(1, 2560, 4096, 4): helion.Config(block_sizes=[1, 128], num_warps=4, num_stages=4),
}


# Optional: add advanced_controls_file to your Config for extra performance (see docs).
# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
# helion.Config(..., advanced_controls_file="/opt/booster_pack/causal_conv_0.acf")


# NOTE: This is an intentionally inefficient baseline implementation.
def _make_kernel(config: helion.Config):
@helion.kernel(static_shapes=True, config=config)
def kernel(
Expand All @@ -46,20 +40,11 @@ def kernel(

for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]):
bi = rb.begin
acc1 = hl.zeros([rd, rs], dtype=torch.float32)
acc2 = hl.zeros([rd, rs], dtype=torch.float32)
acc3 = hl.zeros([rd, rs], dtype=torch.float32)
acc = hl.zeros([rd, rs], dtype=torch.float32)
for j in range(W):
c1 = w[rd, j].to(torch.float32)
x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc1 = acc1 + x1 * c1[:, None]
c2 = w[rd, j].to(torch.float32)
x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc2 = acc2 + x2 * c2[:, None]
c3 = w[rd, j].to(torch.float32)
x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc3 = acc3 + x3 * c3[:, None]
acc = (acc1 + acc2 + acc3) / 3.0
c = w[rd, j].to(torch.float32)
x_val = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
acc = acc + x_val * c[:, None]
acc = acc + b[rd].to(torch.float32)[:, None]
y[rb, rd, rs] = acc[None, :, :].to(y.dtype)

Expand Down
2 changes: 2 additions & 0 deletions problems/helion/fp8_quant_py/prompt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Goal: Maximize Helion Kernel in @submission.py; Ensure correctness using @reference.py. Details: Use @../../../base/examples/ that contains optimized helion examples.
@../../helion/, use "python eval.py test fp8_quant_py/" to test correctness and "python eval.py benchmark fp8_quant_py/" to measure performance.
8 changes: 6 additions & 2 deletions problems/helion/fp8_quant_py/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,21 @@ def ref_kernel(data: input_t) -> output_t:
num_groups = x_s.shape[1]
group_size = hidden_dim // num_groups

# convert to float32 for computation
x_f32 = x.float()
# reshape into fp8 groups
x_grouped = x_f32.reshape(num_tokens, num_groups, group_size)

# Per-group absmax
# Per-group absmax, clamped to a minimum epsilon to avoid division by zero
absmax = x_grouped.abs().amax(dim=-1).clamp(min=FP8_EPS)

# Scale = absmax / fp8_max
# Scale factor: absmax divided by the maximum representable fp8 value
scale = absmax / FP8_MAX

# Quantize
# Quantize by dividing by scale and clamping to fp8 range
quantized = (x_grouped / scale.unsqueeze(-1)).clamp(FP8_MIN, FP8_MAX)
# Reshape to original shape
quantized = quantized.reshape(num_tokens, hidden_dim)

x_q[...] = quantized
Expand Down
62 changes: 27 additions & 35 deletions problems/helion/fp8_quant_py/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,79 +10,71 @@
# Autotune locally for each shape, then paste the best config here.
SHAPE_CONFIGS: dict[tuple, helion.Config] = {
# Test shapes
(1, 256, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(4, 512, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(16, 1024, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(1, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(8, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
(1, 256, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),
(4, 512, 128): helion.Config(block_sizes=[2], num_warps=2, num_stages=1),
(16, 1024, 64): helion.Config(block_sizes=[4], num_warps=2, num_stages=1),
(1, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),
(8, 4096, 128): helion.Config(block_sizes=[4], num_warps=2, num_stages=1),
# Benchmark shapes
# (1, 4096, 128) already covered above
(16, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(256, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(256, 8192, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(4096, 7168, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
(16, 4096, 128): helion.Config(block_sizes=[4], num_warps=4, num_stages=2),
(256, 4096, 128): helion.Config(block_sizes=[8], num_warps=8, num_stages=2),
(256, 8192, 128): helion.Config(block_sizes=[8], num_warps=8, num_stages=2),
(4096, 7168, 128):helion.Config(block_sizes=[16], num_warps=8, num_stages=2),
}


# Optional: add advanced_controls_file to your Config for extra performance (see docs).
# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
# helion.Config(..., advanced_controls_file="/opt/booster_pack/fp8_group_quant_0.acf")


# NOTE: This is an intentionally inefficient baseline implementation.
def _make_kernel(config: helion.Config):
@helion.kernel(static_shapes=True, config=config)
def kernel(
data: torch.Tensor, # [N, G] input rows
scales_out: torch.Tensor, # [N] output normalization factors
data: torch.Tensor, # [N, G] float32
scales_out: torch.Tensor, # [N]
) -> torch.Tensor:
nrows = data.size(0)
ncols = hl.specialize(data.size(1))
MAX_VAL = 448.0

qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device)

for rr in hl.tile(nrows):
row = data[rr, :].to(torch.float32)

abs1 = torch.abs(row)
amax1 = torch.amax(abs1, -1)
abs2 = torch.abs(row)
amax2 = torch.amax(abs2, -1)
abs3 = torch.abs(row)
amax3 = torch.amax(abs3, -1)
amax = (amax1 + amax2 + amax3) / 3.0
amax = torch.clamp(amax, min=1e-10)
amax = torch.amax(torch.abs(row), dim=-1).clamp(min=1e-10)
scale = amax / MAX_VAL

q1 = row / scale[:, None]
q2 = row / scale[:, None]
q3 = row / scale[:, None]
qout[rr, :] = (q1 + q2 + q3) / 3.0
qout[rr, :] = torch.clamp(row / scale[:, None], -MAX_VAL, MAX_VAL)
scales_out[rr] = scale

return qout

return kernel


# Create a kernel for each shape based on the configs above.
_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}


# TODO Fuse reshapes into kernel. Reshape is a no-op for the GPU, but requires
# copy kernel in pytorch eager mode.

def custom_kernel(data: input_t) -> output_t:
x, x_q, x_s = data
# num_tokens, hidden_dim = x.shape
T, H = x.shape
# num_groups = x_s.shape[1]
G = x_s.shape[1]
# group_size = hidden_dim // num_groups
gsz = H // G
# merge num_tokens and num_groups into a single dimension for the kernel
N = T * G

# Select the appropriate kernel based on the input shape.
kernel = _KERNELS[(T, H, gsz)]

# Reshape inputs and scales for the kernel
# inputs is input argument
# scales is output argument, but we need to pass it in as an argument
flat_in = x.reshape(N, gsz)
flat_s = x_s.reshape(N)

# return value is quantized output
flat_q = kernel(flat_in, flat_s)

# Reshape outputs back to original shapes
x_q[...] = flat_q.reshape(T, H)
x_s[...] = flat_s.reshape(T, G)
return x_q, x_s
48 changes: 24 additions & 24 deletions problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,20 @@
import helion.language as hl


# Per-shape configs: map (B, T, H, K, V) to optimized helion.Config objects.
# Autotune locally for each shape, then paste the best config here.
# Per-shape configs: map (B, T, H, K, V) to helion.Config objects.
SHAPE_CONFIGS: dict[tuple, helion.Config] = {
# Test shapes
(1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check
(2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check
(1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check
(1, 64, 2, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2),
(2, 128, 4, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2),
(1, 256, 4, 64, 128): helion.Config(block_sizes=[64], num_warps=8, num_stages=2),
# Benchmark shapes
(1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
(2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
(2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
(3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
(4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
(2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
(4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
(1, 64, 1, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2),
(2, 512, 3, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2),
(2, 1024, 3, 64, 64): helion.Config(block_sizes=[64], num_warps=4, num_stages=2),
(3, 1024, 4, 100, 100): helion.Config(block_sizes=[64], num_warps=8, num_stages=3),
(4, 1024, 4, 128, 128): helion.Config(block_sizes=[64], num_warps=8, num_stages=3),
(2, 1536, 4, 128, 128): helion.Config(block_sizes=[64], num_warps=8, num_stages=3),
(4, 2048, 8, 64, 64): helion.Config(block_sizes=[64], num_warps=8, num_stages=3),
}


Expand All @@ -28,7 +27,6 @@
# helion.Config(..., advanced_controls_file="/opt/booster_pack/chunk_fwd_o_0.acf")


# NOTE: This is an intentionally inefficient baseline implementation.
def _make_kernel(config: helion.Config):
@helion.kernel(static_shapes=True, dot_precision="ieee", config=config)
def kernel(
Expand All @@ -53,24 +51,26 @@ def kernel(
h_idx = flat_bh.begin % H
c_idx = tile_t.begin // C

g_vals = g[b_idx, tile_t, h_idx]
q_tile = q[b_idx, tile_t, h_idx, :]
k_tile = k[b_idx, tile_t, h_idx, :]
v_tile = v[b_idx, tile_t, h_idx, :]
g_tile = g[b_idx, tile_t, h_idx].to(torch.float32)
q_tile = q[b_idx, tile_t, h_idx, :].to(torch.float32)
k_tile = k[b_idx, tile_t, h_idx, :].to(torch.float32)

# intra-chunk: q @ k^T * exp(g_i - g_j), with causal mask
qk = hl.dot(q_tile, k_tile.T)
qk = hl.dot(q_tile, k_tile.T, out_dtype=torch.float32)
idx = hl.arange(tile_t.block_size)
g_diff = g_vals[:, None] - g_vals[None, :]
g_diff = g_tile[:, None] - g_tile[None, :]
causal_mask = idx[:, None] >= idx[None, :]
sim = torch.where(causal_mask, qk * torch.exp(g_diff), 0.0)
local_out = hl.dot(sim.to(v.dtype), v_tile)

# inter-chunk: (q @ h) * exp(g)
q_s = q_tile * torch.exp(g_vals)[:, None]
global_out = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])

out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype)
q_s = q_tile * torch.exp(g_tile)[:, None]

for tv in hl.tile(V):
v_tile = v[b_idx, tile_t, h_idx, tv].to(torch.float32)
h_tile = h[b_idx, c_idx, h_idx, :, tv].to(torch.float32)
local_out = hl.dot(sim, v_tile, out_dtype=torch.float32)
global_out = hl.dot(q_s, h_tile, out_dtype=torch.float32)
out[b_idx, tile_t, h_idx, tv] = ((global_out + local_out) * scale).to(out.dtype)

return out

Expand Down
1 change: 1 addition & 0 deletions problems/helion/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
helion