
[Training Camp] Learning rate scheduler implementation #113

Open
littleotherut wants to merge 19 commits into InfiniTensor:master from littleotherut:lr_scheduler

Conversation

@littleotherut

No description provided.

kinorw and others added 19 commits March 3, 2026 14:42
…base class

- Change Step() to virtual with default implementation
- Add pure virtual ComputeLR() for subclasses to implement
- Adapt test helpers (IdentityScheduler, LinearDecayScheduler) to implement ComputeLR() instead of Step()
- All existing tests pass without behavioral changes

BREAKING CHANGE: Subclasses must implement ComputeLR() instead of Step().
…t and update all tests to use Create<T>() factory method.
…entialLR

- enhance LRScheduler with chained and closed-form learning rate methods
- adapt methods (Step, InitialStep, GetClosedFormLR, GetChainedFormLR) to match PyTorch's design
- add tests for consistency
- refactor LinearLR: add end_factor and rename this class
- add SequentialLR InitialStep and UndoChildInitialSteps

BREAKING CHANGE: Subclasses must implement GetClosedFormLR() instead of ComputeLR(). Use LinearLR instead of LinearwarmupLR.
- Add LRSchedulerConfig struct with parameters for all basic schedulers (constant, linear, step)
- Add CreateLRScheduler() factory function
- Support automatic warmup wrapping via SequentialLR when warmup_steps > 0
- Adapt test files
- Add gflags: --lr_scheduler, --warmup_steps, --step_size, --gamma, --start_factor, --end_factor, --lr_total_iters, --total_steps
- Replace nullptr scheduler with factory-created scheduler
- Move scheduler.Step() after optimizer.Step() in both DP and PP paths
- Replace hardcoded FLAGS_learning_rate in log with scheduler->GetLR()
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
Contributor

The scheduler has already been Step()-ed earlier, so here GetLR() semantically returns the LR to be used on the *next* step; what we want to print is the LR actually used on each step, so this logic needs to be revised. The same applies in the llama3 main.cc.

size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
Contributor

Same issue here.

std::vector<std::shared_ptr<Tensor>> params_;
float learning_rate_ = 0.0f;
float initial_learning_rate_ = 0.0f;
bool initial_lr_set_ = false;
Contributor

This part is somewhat redundant. The optimizer only needs to store learning_rate_, representing the current learning rate; there is no need for extra initial-LR state. Semantically, the initial learning rate can live solely in the LR scheduler (which you effectively already do, via the scheduler's base_lr).


std::shared_ptr<Optimizer> optimizer_;
int64_t last_step_;
float current_lr_;
Contributor

current_lr_ also seems a bit redundant. Semantically, current_lr_ and optimizer_->GetLearningRate() should be equal at all times, yet in your design the two are stored separately and used interchangeably (on a full read, current_lr_ looks like a copy of optimizer_->GetLearningRate()). Numerically you have handled this correctly, but the design is likely to cause ambiguity for whoever extends it later.

I suggest keeping a single source of truth for the "current learning rate": either track it entirely via optimizer_->GetLearningRate() and drop current_lr_ from the LR scheduler, or have the scheduler track it and set it back into the optimizer after each computation. Personally I think the former is more appropriate.


void LRScheduler::ApplyLR(float lr) {
current_lr_ = lr;
optimizer_->SetLearningRate(current_lr_);
Contributor

Following on from the above: on one hand your design calls optimizer_->SetLearningRate(current_lr_);, and on the other it does current_lr_ = optimizer_->GetLearningRate();. The two can blur which is the cause and which the effect, so I recommend keeping the semantics consistent in the design.

scheduler->Step();
}

current_lr_ = optimizer_->GetLearningRate();
Contributor

Following on from the above: on one hand your design calls optimizer_->SetLearningRate(current_lr_);, and on the other it does current_lr_ = optimizer_->GetLearningRate();. The two can blur which is the cause and which the effect, so I recommend keeping the semantics consistent in the design.

} else if (last_step_ < total_iters_) {
return lr;
} else if (last_step_ == total_iters_) {
return lr / factor_;
Contributor

Since some hyperparameter values are passed in by CLI users, invalid-input checks are needed. Here, for example, factor should be in the (0, 1) range; otherwise an invalid value could cause division by zero. The torch implementation also checks this in the constructor, see: https://github.com/pytorch/pytorch/blob/08840d08a02eead8edf22406a53e5691c9a89c9a/torch/optim/lr_scheduler.py#L813

Also, from what I can see, StepLR does not check step_size > 0, and LinearLR does not check its two factors or total_iters, among others. I suggest auditing the whole PR for this.

void LoadState(const StateDict &state) override;

protected:
float GetClosedFormLR() const override { return current_lr_; }
Contributor

The semantics here are not quite right. I looked closely at the torch implementation: if GetClosedFormLR corresponds to torch's get_closed_form_lr interface, it is meant to be a function that, given base_lr, last_step, and the other hyperparameters, computes the current LR from a formula. Although this is numerically equal to the current_lr you return, logically the code should not just return the cached current_lr_; it should provide the actual computation.

Also, torch's _get_closed_form_lr interface is ultimately used by step(int epoch): if an LRScheduler subclass implements _get_closed_form_lr, it supports closed-form step-skipping semantics, and step(epoch) computes the current LR directly from that function. torch's SequentialLR subclass does not implement it.

Given that your GetClosedFormLR is declared as a virtual function that all subclasses must implement, I suggest adding a // FIXME comment here noting this: for now, returning the current LR is a temporary hack, not a closed-form computation.

};

} // namespace lr_schedulers
} // namespace infini_train
\ No newline at end of file
Contributor

Formatting convention: files need a trailing newline at end of file; several other files later in the PR have the same issue.
