[Training Camp] Add FlashAttention operator into infiniTrain Framework #124

Open
Aoshine999 wants to merge 3 commits into InfiniTensor:master from Aoshine999:flashattention
Conversation

@Aoshine999
Add a FlashAttention forward/backward implementation and wire it into the autograd/dispatcher system.
Key changes:

  • infini_train/include/autograd/ScaledDotProductAttention.h
  • infini_train/src/autograd/ScaledDotProductAttention.cc
  • infini_train/include/kernels/cuda/flash_attention.h
  • infini_train/src/kernels/cuda/flash_attention.cu
  • gpt2/llama3 runners: add a --flash flag to switch the attention path

Constraints: dtype must be float32 or bfloat16; the FlashAttention forward and backward kernels only support blockDim (32, 32).
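For reviewers unfamiliar with the kernel, the core of FlashAttention is the online-softmax recurrence: keys/values are processed in tiles while a running max, running denominator, and rescaled output accumulator are maintained, so the full score matrix is never materialized. Below is a minimal CPU sketch of that recurrence for a single query; all names are illustrative and are not taken from the infiniTrain sources, and the CUDA kernels in this PR tile over both queries and keys rather than looping like this.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative single-query online-softmax attention (the FlashAttention
// recurrence), not the actual infiniTrain kernel. Keys/values are walked
// in tiles of size `tile`; m is the running max, l the running softmax
// denominator, acc the output accumulator rescaled whenever m changes.
std::vector<float> flash_attention_1q(const std::vector<float>& q,
                                      const std::vector<std::vector<float>>& K,
                                      const std::vector<std::vector<float>>& V,
                                      int tile) {
    const int n = static_cast<int>(K.size());
    const int d = static_cast<int>(q.size());
    float m = -INFINITY, l = 0.0f;
    std::vector<float> acc(d, 0.0f);
    for (int start = 0; start < n; start += tile) {
        const int end = std::min(n, start + tile);
        // Scaled dot-product scores for this tile, tracking the new max.
        std::vector<float> s(end - start);
        float tile_max = m;
        for (int i = start; i < end; ++i) {
            float dot = 0.0f;
            for (int k = 0; k < d; ++k) dot += q[k] * K[i][k];
            s[i - start] = dot / std::sqrt(static_cast<float>(d));
            tile_max = std::max(tile_max, s[i - start]);
        }
        // Rescale previous partial results to the new running max.
        const float scale = std::exp(m - tile_max);  // exp(-inf)=0 on first tile
        l *= scale;
        for (int k = 0; k < d; ++k) acc[k] *= scale;
        m = tile_max;
        // Accumulate this tile's contribution.
        for (int i = start; i < end; ++i) {
            const float p = std::exp(s[i - start] - m);
            l += p;
            for (int k = 0; k < d; ++k) acc[k] += p * V[i][k];
        }
    }
    for (int k = 0; k < d; ++k) acc[k] /= l;
    return acc;
}

// Naive reference: materialize all scores, softmax, then weight V.
std::vector<float> naive_attention_1q(const std::vector<float>& q,
                                      const std::vector<std::vector<float>>& K,
                                      const std::vector<std::vector<float>>& V) {
    const int n = static_cast<int>(K.size());
    const int d = static_cast<int>(q.size());
    std::vector<float> s(n);
    float mx = -INFINITY;
    for (int i = 0; i < n; ++i) {
        float dot = 0.0f;
        for (int k = 0; k < d; ++k) dot += q[k] * K[i][k];
        s[i] = dot / std::sqrt(static_cast<float>(d));
        mx = std::max(mx, s[i]);
    }
    float denom = 0.0f;
    std::vector<float> out(d, 0.0f);
    for (int i = 0; i < n; ++i) {
        const float p = std::exp(s[i] - mx);
        denom += p;
        for (int k = 0; k < d; ++k) out[k] += p * V[i][k];
    }
    for (int k = 0; k < d; ++k) out[k] /= denom;
    return out;
}
```

The tiled result is bit-for-bit close to the naive one; the point is that the tiled version only ever holds one tile of scores, which is what lets the CUDA kernels keep everything in shared memory within a (32, 32) block.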

@kilinchange kilinchange changed the title Add FlashAttention operator into infiniTrain Framework [Training Camp] Add FlashAttention operator into infiniTrain Framework Mar 17, 2026
@kilinchange
Collaborator

Please resolve the conflicts between this PR and master.

@kilinchange kilinchange self-requested a review March 17, 2026 06:21
@kilinchange kilinchange self-assigned this Mar 17, 2026
