Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
7a16589
refactor(optimizer): hoist learning_rate_ to Optimizer base and add l…
Mar 3, 2026
0514862
refactor(distributed_optimizer): passthrough SetLearningRate/GetLearn…
Mar 3, 2026
81295e8
feat(lr_scheduler): add LRScheduler abstract base class with StateDict
Mar 3, 2026
8e7cda0
refactor(examples): add scheduler placeholder and use runtime lr in logs
Mar 3, 2026
1e65881
feat: add ConstantLR, StepLR and LinearWarmupLR
Mar 4, 2026
d924d3d
refactor(lr_scheduler): replace ComputeLR with virtual Step and Apply…
Mar 5, 2026
baca2ef
feat(lr_schedulers): add LambdaLR strategy
Mar 5, 2026
7df75d7
refactor(optimizer): add initial_learning_rate and its accessors
Mar 5, 2026
d0ac538
feat(lr_schedulers): add SequentialLR composite strategy
Mar 5, 2026
df4c68d
refactor(lr_scheduler): apply template method pattern to LRScheduler …
Mar 5, 2026
5b4ef6d
feat(lr_scheduler): add factory method Create<T>() with two-phase ini…
Mar 5, 2026
8c11dd9
feat(lr_scheduler): add closed and chained form, adjust LinearLR, Sequ…
Mar 6, 2026
6823244
feat(lr_schedulers): add ChainedScheduler composite strategy
Mar 6, 2026
fb9d997
feat(lr_scheduler): add scheduler factory for CLI integration
Mar 8, 2026
7a29a61
feat(lr_scheduler): add scheduler factory for CLI integration (Sequen…
Mar 8, 2026
b64566e
feat(lr_scheduler): add warmup start_factor and end_factor, remove c…
Mar 8, 2026
3a7abb4
refactor(gpt2,llama3): integrate scheduler into training loop
Mar 8, 2026
f7b3fcb
Merge branch 'InfiniTensor:master' into lr_scheduler
littleotherut Mar 11, 2026
1f95e29
style: apply clang-format to all legacy files
littleotherut Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,24 @@ target_link_libraries(test_hook infini_train)

add_executable(test_precision_check test/hook/test_precision_check.cc)
target_link_libraries(test_precision_check infini_train)

# Unit tests for the LR-scheduler feature: one executable per scheduler type,
# each linked against the infini_train library.
add_executable(test_lr_scheduler test/lr_scheduler/test_lr_scheduler.cc)
target_link_libraries(test_lr_scheduler infini_train)

add_executable(test_constant_lr test/lr_scheduler/test_constant_lr.cc)
target_link_libraries(test_constant_lr infini_train)

add_executable(test_step_lr test/lr_scheduler/test_step_lr.cc)
target_link_libraries(test_step_lr infini_train)

add_executable(test_linear_lr test/lr_scheduler/test_linear_lr.cc)
target_link_libraries(test_linear_lr infini_train)

add_executable(test_lambda_lr test/lr_scheduler/test_lambda_lr.cc)
target_link_libraries(test_lambda_lr infini_train)

add_executable(test_sequential_lr test/lr_scheduler/test_sequential_lr.cc)
target_link_libraries(test_sequential_lr infini_train)

add_executable(test_chained_lr test/lr_scheduler/test_chained_lr.cc)
target_link_libraries(test_chained_lr infini_train)
37 changes: 34 additions & 3 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
Expand Down Expand Up @@ -55,6 +56,16 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
// Base learning rate handed to the optimizer (and used as the scheduler's base
// LR). The help string previously said "learning rate warmup iterations",
// which described a different flag — fixed to match the flag's actual meaning.
DEFINE_double(learning_rate, 1e-4, "learning rate");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
// lr scheduler
DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");
DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay");
DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay");
DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor");
DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor");
DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -268,6 +279,20 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(model->Parameters());
}

// Translate the CLI flags into a scheduler configuration and build the
// scheduler for this optimizer. Call sites null-check `scheduler`, so
// CreateLRScheduler presumably returns nullptr for --lr_scheduler=none.
LRSchedulerConfig sched_config;
sched_config.type = FLAGS_lr_scheduler;
sched_config.warmup_steps = FLAGS_warmup_steps;
sched_config.warmup_start_factor = static_cast<float>(FLAGS_warmup_start_factor);
sched_config.warmup_end_factor = static_cast<float>(FLAGS_warmup_end_factor);
sched_config.step_size = FLAGS_step_size;
sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // reused: --start_factor also drives ConstantLR's factor
sched_config.constant_total_iters = FLAGS_lr_total_iters;
sched_config.linear_total_iters = FLAGS_lr_total_iters;
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
Expand Down Expand Up @@ -354,6 +379,9 @@ void Train(const nn::parallel::Rank &rank) {
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
Expand All @@ -363,6 +391,9 @@ void Train(const nn::parallel::Rank &rank) {
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
Expand All @@ -378,11 +409,11 @@ void Train(const nn::parallel::Rank &rank) {
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

// FIXME(review): scheduler->Step() has already run for this iteration, so
// GetLR() reports the LR for the *next* step; capture the LR before stepping
// if the log should show the value actually used this step.
const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scheduler 在前面已经 Step 过了,所以这里 GetLR() 语义上是”下一步要用到的 lr“;而我们这里想打印的是每一步实际用到的 lr,所以这里的逻辑需要修改下。llama3 部分的 main.cc 里同理。

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
Expand Down
37 changes: 34 additions & 3 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
Expand Down Expand Up @@ -54,6 +55,16 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
// Base learning rate handed to the optimizer (and used as the scheduler's base
// LR). The help string previously said "learning rate warmup iterations",
// which described a different flag — fixed to match the flag's actual meaning.
DEFINE_double(learning_rate, 1e-5, "learning rate");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
// lr scheduler
DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");
DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay");
DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay");
DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor");
DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor");
DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -247,6 +258,20 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(model->Parameters());
}

// Translate the CLI flags into a scheduler configuration and build the
// scheduler for this optimizer. Call sites null-check `scheduler`, so
// CreateLRScheduler presumably returns nullptr for --lr_scheduler=none.
LRSchedulerConfig sched_config;
sched_config.type = FLAGS_lr_scheduler;
sched_config.warmup_steps = FLAGS_warmup_steps;
sched_config.warmup_start_factor = static_cast<float>(FLAGS_warmup_start_factor);
sched_config.warmup_end_factor = static_cast<float>(FLAGS_warmup_end_factor);
sched_config.step_size = FLAGS_step_size;
sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // reused: --start_factor also drives ConstantLR's factor
sched_config.constant_total_iters = FLAGS_lr_total_iters;
sched_config.linear_total_iters = FLAGS_lr_total_iters;
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
Expand Down Expand Up @@ -330,6 +355,9 @@ void Train(const nn::parallel::Rank &rank) {
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
Expand All @@ -339,6 +367,9 @@ void Train(const nn::parallel::Rank &rank) {
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
Expand All @@ -354,11 +385,11 @@ void Train(const nn::parallel::Rank &rank) {
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

// FIXME(review): scheduler->Step() has already run for this iteration, so
// GetLR() reports the LR for the *next* step; capture the LR before stepping
// if the log should show the value actually used this step.
const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同理

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
Expand Down
186 changes: 186 additions & 0 deletions infini_train/include/lr_scheduler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#pragma once

#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

namespace infini_train {

class Optimizer;

using StateValue = std::variant<int64_t, float, double, std::string, std::vector<float>>;
using StateDict = std::unordered_map<std::string, StateValue>;

// Flat configuration bag consumed by CreateLRScheduler(). Only the fields
// relevant to the scheduler selected via `type` are read; the rest keep their
// defaults.
struct LRSchedulerConfig {
std::string type = "none"; // selects the concrete scheduler ("none" = no scheduler)
// ConstantLR
float constant_factor = 1.0f / 3.0f; // constant multiplicative LR factor
int constant_total_iters = 5;        // number of steps the factor is applied
// StepLR
int64_t step_size = 10;  // period (in steps) of LR decay
float step_gamma = 0.1f; // multiplicative decay factor per period
// LinearLR
float linear_start_factor = 1.0f / 3.0f; // starting multiplicative factor
float linear_end_factor = 1.0f;          // ending multiplicative factor
int linear_total_iters = 5;              // steps over which the factor changes
// LambdaLR
std::function<float(int64_t)> lambda_fn = nullptr; // maps step index -> LR factor
// SequentialLR
std::vector<LRSchedulerConfig> sequential_configs; // child scheduler configs, in order
std::vector<int64_t> sequential_milestones;        // step boundaries between consecutive children
// ChainedScheduler
std::vector<LRSchedulerConfig> chained_configs; // children whose adjustments compose
// warmup
int64_t warmup_steps = 0; // 0 disables the linear warmup phase
float warmup_start_factor = 1.0f / 3.0f;
float warmup_end_factor = 1.0f;
};

class LRScheduler {
public:
template <typename T, typename... Args> static std::shared_ptr<T> Create(Args &&...args) {
auto scheduler = std::make_shared<T>(std::forward<Args>(args)...);
scheduler->InitialStep();
return scheduler;
}

explicit LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step = -1);
virtual ~LRScheduler() = default;

LRScheduler(const LRScheduler &) = delete;
LRScheduler &operator=(const LRScheduler &) = delete;

virtual void Step();
virtual void Step(int64_t epoch);
virtual void InitialStep();

float GetLR() const;
float BaseLR() const;
int64_t LastStep() const;

void ResetStep(int64_t step = -1);
virtual StateDict State() const;
virtual void LoadState(const StateDict &state);

protected:
virtual float GetClosedFormLR() const = 0;
virtual float GetChainedFormLR() const;
void ApplyLR(float lr);

std::shared_ptr<Optimizer> optimizer_;
int64_t last_step_;
float current_lr_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

current_lr_ 似乎也有点冗余,语义上 current_lr_ 和 optimizer_->GetLearningRate() 的值在任何时候应等价,现在在你的设计里看到这二者存在各自分开存且混用的状态(读完发现目前的 current_lr_ 像是 optimizer_->GetLearningRate() 的一个副本);目前的数值正确性上你处理的没问题,但是这种设计交给后人来扩展的时候很可能带来歧义。

建议针对“当前学习率”只保留唯一真状态来源,要么就全程由 optimizer_->GetLearningRate() 跟踪,lr scheduler 里面就不存 current lr 了;要么就由 lr scheduler 跟踪,每次计算完再 set 回 optimizer。个人认为前者较合适。

float base_lr_;
bool is_initial_ = false;
};

std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config);

namespace lr_schedulers {

// Applies a constant multiplicative `factor` to the base LR for the first
// `total_iters` steps (cf. the --start_factor / --lr_total_iters CLI flags).
// Schedule math lives in the corresponding .cc.
class ConstantLR : public LRScheduler {
public:
ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor = 1.0f / 3.0f, int total_iters = 5,
int64_t last_step = -1);
~ConstantLR() override = default;

protected:
float GetChainedFormLR() const override; // incremental form (used when composed in a chain)
float GetClosedFormLR() const override;  // closed form: LR computed directly from last_step

private:
const float factor_;
const int64_t total_iters_;
};

// Decays the LR by a multiplicative factor `gamma` once every `step_size`
// steps (cf. the --step_size / --gamma CLI flags).
class StepLR : public LRScheduler {
public:
StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1);
~StepLR() override = default;

protected:
float GetChainedFormLR() const override; // incremental form (used when composed in a chain)
float GetClosedFormLR() const override;  // closed form: LR computed directly from last_step

private:
const int64_t step_size_; // decay period in steps
const float gamma_;       // decay factor applied each period
};

// Linearly interpolates the LR multiplier from `start_factor` to `end_factor`
// over `total_iters` steps (cf. the --start_factor / --end_factor /
// --lr_total_iters CLI flags); also used for warmup via warmup_* config.
class LinearLR : public LRScheduler {
public:
LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f,
int64_t total_iters = 5, int64_t last_step = -1);
~LinearLR() override = default;

protected:
float GetChainedFormLR() const override; // incremental form (used when composed in a chain)
float GetClosedFormLR() const override;  // closed form: LR computed directly from last_step

private:
const float start_factor_;
const float end_factor_;
const int64_t total_iters_;
};

// Drives the LR from a user-supplied callback mapping the step index to a
// multiplicative factor. Does not override GetChainedFormLR(), so it falls
// back to the base-class chained behavior.
class LambdaLR : public LRScheduler {
public:
using LambdaFunc = std::function<float(int64_t)>;

LambdaLR(std::shared_ptr<Optimizer> optimizer, LambdaFunc lr_lambda, int64_t last_step = -1);
~LambdaLR() override = default;

protected:
float GetClosedFormLR() const override;

private:
const LambdaFunc lr_lambda_; // step -> LR factor
};

class SequentialLR : public LRScheduler {
public:
SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
std::vector<int64_t> milestones, int64_t last_step = -1);
~SequentialLR() override = default;

void Step() override;
void InitialStep() override;

StateDict State() const override;
void LoadState(const StateDict &state) override;

protected:
// FIXME(review): not a genuine closed-form computation. The closed form is
// meant to derive the LR analytically from base_lr / last_step (torch's
// _get_closed_form_lr), and torch's SequentialLR does not implement it at
// all; returning the cached current_lr_ is a stopgap to satisfy the
// pure-virtual interface.
float GetClosedFormLR() const override { return current_lr_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块语义上不太对,我仔细看了下 torch 里面的实现,GetClosedFormLR 对标 torch 里提供的 get_closed_form_lr 的接口的话, 实际是想实现一个“给定 base_lr、last_step 以及其他超参,然后可以通过公式算出当前 lr 的 function”。这个虽然数值上确实等于你现在提供的 current_lr,但是逻辑上的代码不应该直接返回缓存的 current_lr_ 就完事,而是应该给一个计算公式。

另外,torch 里提供的 _get_closed_form_lr 的接口,最终实际上是用于 step(int epoch) 这个 function 的,如果对应的 LRScheduler 派生类实现了这个 _get_closed_form_lr,就代表其支持 closed form 的跳步语义,然后 step(epoch) 会直接由提供的 function 计算出 current lr。而 torch 里面的 SequentialLR 派生类没有实现这个 function。

考虑到你这边的 GetClosedFormLR 定义为虚函数,要求所有派生类必须实现,我建议是在这里加上一个 // FIXME 的注释说明一下这一点,目前暂时先以返回一个 current lr 来 hack 实现,而不是提供了 closed-form 计算方法。

void UndoChildInitialSteps();

private:
std::vector<std::shared_ptr<LRScheduler>> schedulers_;
std::vector<int64_t> milestones_;
};

// Composite scheduler owning several children whose LR adjustments compose:
// each Step() presumably steps every child in order (torch ChainedScheduler
// analogue — confirm against the .cc implementation).
class ChainedScheduler : public LRScheduler {
public:
ChainedScheduler(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
int64_t last_step = -1);
~ChainedScheduler() override = default;

void Step() override;
void InitialStep() override;

StateDict State() const override;
void LoadState(const StateDict &state) override;

protected:
// FIXME(review): not a genuine closed-form computation — returning the cached
// current_lr_ is a stopgap to satisfy the pure-virtual interface; the closed
// form should derive the LR analytically from base_lr / last_step.
float GetClosedFormLR() const override { return current_lr_; }

private:
std::vector<std::shared_ptr<LRScheduler>> schedulers_; // children applied together
};

} // namespace lr_schedulers
} // namespace infini_train
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

format 规范上,end of file 需要有一个 newline,后续也有几个文件存在这个问题

3 changes: 3 additions & 0 deletions infini_train/include/nn/parallel/ddp/distributed_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class DistributedOptimizer final : public infini_train::Optimizer {
void StartParamSync(bool force_sync = false);
void FinishParamSync(bool skip_next_bucket_dispatch = false);

// Learning-rate accessors: pass through to the wrapped optimizer state so LR
// schedulers can drive a DistributedOptimizer transparently (see the
// passthrough commit). Dropped the redundant `virtual` — `override` already
// implies it (C++ Core Guidelines C.128: use exactly one of
// virtual/override/final).
void SetLearningRate(float lr) override;
float GetLearningRate() const override;

private:
void BuildShardParamsAndBindGrads();

Expand Down
Loading
Loading