Skip to content
Closed
9 changes: 8 additions & 1 deletion example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
DEFINE_bool(flash, false, "Whether to enable FlashAttention in CausalSelfAttention");

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
Expand Down Expand Up @@ -180,12 +181,18 @@ void Train(const nn::parallel::Rank &rank) {
// init the model, either from scratch or from OpenAI pretrained checkpoint
GPT2Config model_config;
std::shared_ptr<nn::Module> model = nullptr;
LOG(INFO) << "Rank " << rank.GlobalRank() << ": FLAGS_flash = " << (FLAGS_flash ? "true" : "false");
if (!FLAGS_llmc_filepath.empty()) {
model = GPT2::FromLLMC(FLAGS_llmc_filepath);
LOG(INFO) << "Rank " << rank.GlobalRank() << ": Loading GPT2 from LLMC file: " << FLAGS_llmc_filepath;
model = GPT2::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else if (kModelToConfigs.count(FLAGS_model)) {
model_config = kModelToConfigs.at(FLAGS_model);
model_config.flash = FLAGS_flash;
model = std::make_shared<GPT2>(model_config);
} else {
if (FLAGS_flash) {
LOG(WARNING) << "--flash is ignored when loading GPT2 from pretrained checkpoint.";
}
model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
}

Expand Down
43 changes: 29 additions & 14 deletions example/gpt2/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "example/common/utils.h"
#include "infini_train/include/device.h"
#include "infini_train/include/autograd/ScaledDotProductAttention.h"
#include "infini_train/include/nn/functional.h"
#include "infini_train/include/nn/init.h"
#include "infini_train/include/nn/modules/container.h"
Expand Down Expand Up @@ -105,18 +106,31 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
std::shared_ptr<Tensor> y = nullptr;
if (config_.flash) {
// FlashAttention expects (B, T, H, D)
auto q_flash = q->Transpose(1, 2);
auto k_flash = k->Transpose(1, 2);
auto v_flash = v->Transpose(1, 2);
auto y_flash = std::make_shared<autograd::ScaledDotProductAttention>(
/*attn_mask=*/nullptr, /*dropout_p=*/0, /*is_causal=*/true,
/*scale=*/1.0 / std::sqrt(static_cast<double>(head_dim)), /*enable_gqa=*/false)
->Apply({q_flash, k_flash, v_flash})[0];
y = y_flash->Contiguous()->View({B, T, local_C});
} else {
// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
}

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
Expand Down Expand Up @@ -351,7 +365,7 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
}
} // namespace

std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -379,7 +393,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
.original_vocab_size = vocab_size,
.n_layer = n_layer,
.n_head = n_head,
.n_embd = n_embd});
.n_embd = n_embd,
.flash = flash});

LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
<< " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
Expand Down
3 changes: 2 additions & 1 deletion example/gpt2/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct GPT2Config {
int64_t n_layer = 12;
int64_t n_head = 12;
int64_t n_embd = 768;
bool flash = false;
};

class NewGELU : public infini_train::nn::CloneableModule<NewGELU> {
Expand Down Expand Up @@ -140,7 +141,7 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<GPT2> FromPretrained(ModelType model_type);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const;

Expand Down
6 changes: 5 additions & 1 deletion example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
DEFINE_bool(flash, false, "Whether to enable FlashAttention in CausalSelfAttention");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
// precision check
Expand Down Expand Up @@ -161,9 +162,12 @@ void Train(const nn::parallel::Rank &rank) {
// ManualSeed(42);

LLaMA3Config model_config = LLaMA3Config();
model_config.flash = FLAGS_flash;
std::shared_ptr<nn::Module> model = nullptr;
LOG(INFO) << "Rank " << rank.GlobalRank() << ": FLAGS_flash = " << (FLAGS_flash ? "true" : "false");
if (!FLAGS_llmc_filepath.empty()) {
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
LOG(INFO) << "Rank " << rank.GlobalRank() << ": Loading LLaMA3 from LLMC file: " << FLAGS_llmc_filepath;
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else {
model = std::make_shared<LLaMA3>(model_config);
}
Expand Down
64 changes: 34 additions & 30 deletions example/llama3/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "glog/logging.h"

#include "example/common/utils.h"
#include "infini_train/include/autograd/ScaledDotProductAttention.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/functional.h"
#include "infini_train/include/nn/init.h"
Expand Down Expand Up @@ -207,36 +208,38 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
// TODO(zbl): use kv cache during inference
// if (use_kv_) { ... }

// align n_head in GQA
// (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
k = RepeatKV(k, n_rep_);
v = RepeatKV(v, n_rep_);

// (B, T, H_local, D) -> (B, H_local, T, D)
q = q->Transpose(1, 2);
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// TODO(zbl): support flash attention later
// if (flash_) { ... }

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
std::shared_ptr<Tensor> y = nullptr;
if (config_.flash) {
auto y_flash = std::make_shared<autograd::ScaledDotProductAttention>(
/*attn_mask=*/nullptr, /*dropout_p=*/0, /*is_causal=*/true,
/*scale=*/1.0 / std::sqrt(static_cast<double>(D)), /*enable_gqa=*/true)
->Apply({q, k, v})[0];
y = y_flash->Contiguous()->View({B, T, C_local});
} else {
// align n_head in GQA
// (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
k = RepeatKV(k, n_rep_);
v = RepeatKV(v, n_rep_);

// (B, T, H_local, D) -> (B, H_local, T, D)
q = q->Transpose(1, 2);
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
auto y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
// output projection
// (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
Expand Down Expand Up @@ -457,7 +460,7 @@ constexpr int32_t kLLaMA3Magic = 20240803;
constexpr int32_t kLLaMA3FP32Version = 3;
} // namespace

std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -496,6 +499,7 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
.rope_theta = rope_theta,
.use_scaled_rope = static_cast<bool>(use_scaled_rope),
.norm_eps = norm_eps,
.flash = flash,
.max_gen_batch_size = max_gen_bs});

// ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
Expand Down
4 changes: 2 additions & 2 deletions example/llama3/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct LLaMA3Config {

// Inference
bool use_kv = false; // kv cache
bool flash = false; // flash attention
bool flash = false; // enable flash attention path in CausalSelfAttention
int64_t max_gen_batch_size = 4; // max batch size during inference
};

Expand Down Expand Up @@ -179,7 +179,7 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<LLaMA3> FromPretrained(ModelType model_type);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); }

Expand Down
42 changes: 42 additions & 0 deletions infini_train/include/autograd/ScaledDotProductAttention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include "infini_train/include/autograd/function.h"
#include "infini_train/include/kernels/cuda/flash_attention.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {
/// Autograd function wrapping the FlashAttention kernel for scaled
/// dot-product attention. Invoked via Apply({q, k, v}); callers in the
/// examples pass tensors laid out as (B, T, H, D) — confirm against the
/// kernel contract before relying on this elsewhere.
class ScaledDotProductAttention : public Function {
public:
    static constexpr char kType[] = "ScaledDotProductAttentionFunction";

    /// @param attn_mask  Optional attention mask; nullptr means no explicit mask.
    /// @param dropout_p  Dropout applied to attention weights.
    ///                   NOTE(review): typed int64_t, so only 0 or 1 are
    ///                   representable — presumably this should be double to
    ///                   allow fractional rates; confirm against the CUDA
    ///                   kernel signature before changing.
    /// @param is_causal  Apply a causal (lower-triangular) mask when true.
    /// @param scale      Optional softmax scaling factor; std::nullopt defers
    ///                   the choice of scale to the kernel.
    /// @param enable_gqa Enable grouped-query attention (K/V head count may be
    ///                   smaller than Q head count).
    // explicit: every parameter is defaulted, so without it this would also be
    // an implicit converting constructor from std::shared_ptr<Tensor>.
    explicit ScaledDotProductAttention(std::shared_ptr<Tensor> attn_mask = nullptr, int64_t dropout_p = 0,
                                       bool is_causal = false, std::optional<double> scale = std::nullopt,
                                       bool enable_gqa = false)
        : Function(kType), attn_mask_(std::move(attn_mask)), dropout_p_(dropout_p), is_causal_(is_causal),
          scale_(scale), enable_gqa_(enable_gqa) {}

    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    std::shared_ptr<Tensor> attn_mask_;
    int64_t dropout_p_ = 0;
    bool is_causal_ = false;
    std::optional<double> scale_;
    bool enable_gqa_ = false;

    // Temporary storage for FlashAttentionForwardOutput to be used in SetupContext
    // Note: This is defined in infini_train::kernels::cuda namespace
    kernels::cuda::FlashAttentionForwardOutput flash_output_;
};
} // namespace infini_train::autograd
27 changes: 27 additions & 0 deletions infini_train/include/kernels/cuda/flash_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include <memory>

namespace infini_train {
class Tensor;
}

namespace infini_train::kernels::cuda {

/**
 * Holds the tensors produced by the FlashAttention forward pass.
 *
 * `logsumexp` and `dropout_seed` are not part of the attention result itself;
 * they are saved so the backward pass can recompute softmax statistics and
 * replay the same dropout mask.
 */
struct FlashAttentionForwardOutput {
    // Attention output, shape [batch_size, seq_len_q, num_heads, head_dim].
    std::shared_ptr<Tensor> output;
    // Per-row softmax log-sum-exp, shape [batch_size, num_heads, seq_len_q];
    // consumed by the backward pass.
    std::shared_ptr<Tensor> logsumexp;
    // RNG seed used for dropout, shape [1]; lets backward reproduce the mask.
    std::shared_ptr<Tensor> dropout_seed;
};

} // namespace infini_train::kernels::cuda
Loading
Loading