
[Training Camp] FlashAttention Integration #125

Open
BAI-123-GUO wants to merge 2 commits into InfiniTensor:master from BAI-123-GUO:master

Conversation

@BAI-123-GUO

Summary

  • Integrate FlashAttention/SDPA via the cuDNN Frontend Graph API, add a --flash switch to GPT-2 and LLaMA-3, and wire up the functional / autograd / CUDA kernel paths.
  • BF16 validation and benchmarking completed on A100 + CUDA 12.8 + cuDNN 9.7; per the hardware/backend support boundary, the flash path targets CUDA + BF16 only.
  • Add benchmark and log-parsing scripts to make the experiments reproducible and the performance results easy to regenerate.

Main Changes

  • example/gpt2/main.cc
  • example/gpt2/net.cc
  • example/llama3/main.cc
  • example/llama3/net.cc
  • infini_train/include/nn/functional.h
  • infini_train/src/nn/functional.cc
  • infini_train/include/autograd/scaled_dot_product_attention.h
  • infini_train/src/autograd/scaled_dot_product_attention.cc
  • infini_train/src/kernels/cuda/scaled_dot_product_attention.cu
  • scripts/flash_sdpa_benchmark.bash
  • scripts/flash_sdpa_parse.py

Validation

  • GPU: NVIDIA A100-SXM4-80GB
  • CUDA: 12.8
  • cuDNN: 9.7.0

Benchmark (baseline -> flash)
GPT-2: 238.37 ms -> 203.90 ms (1.169x speedup)
LLaMA-3: 1299.91 ms -> 1278.74 ms (1.017x speedup)

Notes
The report and logs are not included in this PR and will be submitted separately.

@kilinchange kilinchange self-requested a review March 17, 2026 06:20
@kilinchange kilinchange self-assigned this Mar 17, 2026
