Conversation
…base class - Change Step() to virtual with default implementation - Add pure virtual ComputeLR() for subclasses to implement. - Adapt test helpers (IdentityScheduler, LinearDecayScheduler) to implement ComputeLR() instead of Step(). - All existing tests pass without behavioral changes. BREAKING CHANGE: Subclasses must implement ComputeLR() instead of Step().
…t and update all tests to use Create<T>() factory method.
…entialLR - enhance LRScheduler with chained and closed form learning rate methods - adapt methods(Step, InitialStep, GetClosedFormLR, GetChainedFormLR) to match PyTorch‘s design - add tests for consistency - refactor LinearLR: add end_factor, and rename this class - add SequentialLR InitialStep and UndoChildInitialSteps BREAKING CHANGE: Subclasses must implement GetClosedFormLR instead of ComputeLR(). Should use LinearLR instead of LinearwarmupLR.
- Add LRSchedulerConfig struct with parameters for all basic schedulers(constant, linear, step) - Add CreateLRScheduler() factory function - Support automatic warmup wrapping via SequentialLR when warmup_steps > 0 - Adapt test files
…tial, Chained, and Lambda)
…ommon total_iters
- Add gflags: --lr_scheduler, --warmup_steps, --step_size, --gamma, --start_factor, --end_factor, --lr_total_iters, --total_steps - Replace nullptr scheduler with factory-created scheduler - Move scheduler.Step() after optimizer.Step() in both DP and PP paths - Replace hardcoded FLAGS_learning_rate in log with scheduler->GetLR()
| size_t used_mb = 0, reserved_mb = 0; | ||
| std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); | ||
|
|
||
| const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate); |
There was a problem hiding this comment.
scheduler 在前面已经 Step 过了,所以这里 GetLR() 语义上是”下一步要用到的 lr“;而我们这里想打印的是每一步实际用到的 lr,所以这里的逻辑需要修改下。llama3 部分的 main.cc 里同理。
| size_t used_mb = 0, reserved_mb = 0; | ||
| std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); | ||
|
|
||
| const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate); |
| std::vector<std::shared_ptr<Tensor>> params_; | ||
| float learning_rate_ = 0.0f; | ||
| float initial_learning_rate_ = 0.0f; | ||
| bool initial_lr_set_ = false; |
There was a problem hiding this comment.
这部分比较冗余。optimizer 里面可以只存有代表当前学习率的 learning_rate_,不需要额外存 initial lr 的状态;语义上初始学习率可以仅存在 lr scheduler 里(你是实际上已经这样做了,存在 lr scheduler 的 base_lr)。
|
|
||
| std::shared_ptr<Optimizer> optimizer_; | ||
| int64_t last_step_; | ||
| float current_lr_; |
There was a problem hiding this comment.
current_lr_ 似乎也有点冗余,语义上 current_lr_ 和 optimizer_->GetLearningRate() 的值在任何时候应等价,现在在你的设计里看到这二者存在各自分开存且混用的状态(读完发现目前的 current_lr_ 像是 optimizer_->GetLearningRate() 的一个副本);目前的数值正确性上你处理的没问题,但是这种设计交给后人来扩展的时候很可能带来歧义。
建议针对“当前学习率”只保留唯一真状态来源,要么就全程由 optimizer_->GetLearningRate() 跟踪,lr scheduler 里面就不存 current lr 了;要么就由 lr scheduler 跟踪,每次计算完再 set 回 optimizer。个人认为前者较合适。
|
|
||
| void LRScheduler::ApplyLR(float lr) { | ||
| current_lr_ = lr; | ||
| optimizer_->SetLearningRate(current_lr_); |
There was a problem hiding this comment.
承接上面所说的,在你的设计中一方面看到有 optimizer_->SetLearningRate(current_lr_); 这种调用,另一方面又有 current_lr_ = optimizer_->GetLearningRate();,二者可能会存在谁因谁果的混淆,所以建议保持设计上语义的一致性。
| scheduler->Step(); | ||
| } | ||
|
|
||
| current_lr_ = optimizer_->GetLearningRate(); |
There was a problem hiding this comment.
承接上面所说的,在你的设计中一方面看到有 optimizer_->SetLearningRate(current_lr_); 这种调用,另一方面又有 current_lr_ = optimizer_->GetLearningRate();,二者可能会存在谁因谁果的混淆,所以建议保持设计上语义的一致性。
| } else if (last_step_ < total_iters_) { | ||
| return lr; | ||
| } else if (last_step_ == total_iters_) { | ||
| return lr / factor_; |
There was a problem hiding this comment.
个别超参的值由于是由 cli 用户传入,所以需要加一下非法检查。以此处为例,factor 应该是 (0, 1) 范围内的,不然可能会存在除零的非法值。torch 实现中也在构造函数中做了检查,参考:https://github.com/pytorch/pytorch/blob/08840d08a02eead8edf22406a53e5691c9a89c9a/torch/optim/lr_scheduler.py#L813
另外,以我看到的,还有 StepLR 没检查 step_size > 0,LinearLR 没检查两个 factor 以及 total_iters 等。建议通篇 check 一下。
| void LoadState(const StateDict &state) override; | ||
|
|
||
| protected: | ||
| float GetClosedFormLR() const override { return current_lr_; } |
There was a problem hiding this comment.
这块语义上不太对,我仔细看了下 torch 里面的实现,GetClosedFormLR 对标 torch 里提供的 get_closed_form_lr 的接口的话, 实际是想实现一个“给定 base_lr、last_step 以及其他超参,然后可以通过公式算出当前 lr 的 function”。这个虽然数值上确实等于你现在提供的 current_lr,但是逻辑上的代码不应该直接返回缓存的 current_lr_ 就完事,而是应该给一个计算公式。
另外,torch 里提供的 _get_closed_form_lr 的接口,最终实际上是用于 step(int epoch) 这个 function 的,如果对应的 LRScheduler 派生类实现了这个 _get_closed_form_lr,就代表其支持 closed form 的跳步语义,然后 step(epoch) 会直接由提供的 function 计算出 current lr。而 torch 里面的 SequentialLR 派生类没有实现这个 function。
考虑到你这边的 GetClosedFormLR 定义为虚函数,要求所有派生类必须实现,我建议是在这里加上一个 // FIXME 的注释说明一下这一点,目前暂时先以返回一个 current lr 来 hack 实现,而不是提供了 closed-form 计算方法。
| }; | ||
|
|
||
| } // namespace lr_schedulers | ||
| } // namespace infini_train No newline at end of file |
There was a problem hiding this comment.
format 规范上,end of file 需要有一个 newline,后续也有几个文件存在这个问题
No description provided.