
[Training Camp] feat: integrate Flash Attention operator into InfiniTrain framework #119

Open
tangguochuan wants to merge 2 commits into InfiniTensor:master from tangguochuan:feat/flash-attention

Conversation

@tangguochuan

Add a self-contained Flash Attention forward/backward implementation (BLOCK_Q=64, BLOCK_KV=64, sm_80+, bf16 only) and wire it into the autograd/dispatcher system.

Key changes:

  • infini_train/include/autograd/flash_attention.h: FlashAttention autograd Function declaration
  • infini_train/src/autograd/flash_attention.cc: forward/backward with saved tensors {Q, K, V, O, L}; L (the per-row logsumexp) is passed through SetupContext
  • infini_train/src/kernels/cuda/flash_attention.cu: self-contained CUDA kernel (inlined tiling logic, MMA m16n8k16, online softmax); GQA supported (q_head != kv_head); must use the framework's NonBlocking stream
  • CMakeLists.txt: build flash_attention.cu as a separate sm_80;90 target (infini_train_flash_attention) to avoid compile failures on sm_75
  • example/gpt2, example/llama3: add a --flash flag to switch the attention path

Constraints: dtype=bfloat16 only, head_dim=64 only.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
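The split-architecture build described in the CMakeLists.txt bullet can be sketched roughly as follows (target and file names are taken from the PR description; the exact CMake wiring and the main target's name are assumptions). The idea is to compile only the Flash Attention translation unit for sm_80/90, so the rest of the project can still build for sm_75 without nvcc rejecting the bf16 MMA instructions:

```cmake
# Sketch only: build the Flash Attention kernel as its own CUDA library
# restricted to sm_80 and sm_90, then link it into the main library.
add_library(infini_train_flash_attention STATIC
    infini_train/src/kernels/cuda/flash_attention.cu)
set_target_properties(infini_train_flash_attention PROPERTIES
    CUDA_ARCHITECTURES "80;90")
# "infini_train" as the consuming target is an assumption for illustration.
target_link_libraries(infini_train PRIVATE infini_train_flash_attention)
```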
@kilinchange kilinchange changed the title feat: integrate Flash Attention operator into InfiniTrain framework [Training Camp] feat: integrate Flash Attention operator into InfiniTrain framework Mar 16, 2026
@kilinchange kilinchange self-requested a review March 16, 2026 07:08
@kilinchange
Collaborator

Please resolve the conflicts between this PR and master.

@tangguochuan
Author

I closed PR 118. I have resolved this PR's conflicts, which produced a new commit on this PR. Should I open another branch and submit a clean PR?

