-
Notifications
You must be signed in to change notification settings - Fork 44
【训练营】学习率调度器实现 #113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
【训练营】学习率调度器实现 #113
Changes from all commits
7a16589
0514862
81295e8
8e7cda0
1e65881
d924d3d
baca2ef
7df75d7
d0ac538
df4c68d
5b4ef6d
8c11dd9
6823244
fb9d997
7a29a61
b64566e
3a7abb4
f7b3fcb
1f95e29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| #include "infini_train/include/core/runtime/device_guard.h" | ||
| #include "infini_train/include/dataloader.h" | ||
| #include "infini_train/include/device.h" | ||
| #include "infini_train/include/lr_scheduler.h" | ||
| #include "infini_train/include/nn/modules/loss.h" | ||
| #include "infini_train/include/nn/modules/module.h" | ||
| #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" | ||
|
|
@@ -54,6 +55,16 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); | |
| // optimization | ||
| DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); | ||
| DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); | ||
| // lr scheduler | ||
| DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear"); | ||
| DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)"); | ||
| DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)"); | ||
| DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)"); | ||
| DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay"); | ||
| DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay"); | ||
| DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor"); | ||
| DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor"); | ||
| DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler"); | ||
| // evaluation | ||
| DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); | ||
| DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); | ||
|
|
@@ -247,6 +258,20 @@ void Train(const nn::parallel::Rank &rank) { | |
| optimizer = optimizer_creator(model->Parameters()); | ||
| } | ||
|
|
||
| LRSchedulerConfig sched_config; | ||
| sched_config.type = FLAGS_lr_scheduler; | ||
| sched_config.warmup_steps = FLAGS_warmup_steps; | ||
| sched_config.warmup_start_factor = static_cast<float>(FLAGS_warmup_start_factor); | ||
| sched_config.warmup_end_factor = static_cast<float>(FLAGS_warmup_end_factor); | ||
| sched_config.step_size = FLAGS_step_size; | ||
| sched_config.step_gamma = static_cast<float>(FLAGS_gamma); | ||
| sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor); | ||
| sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor); | ||
| sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // 复用 | ||
| sched_config.constant_total_iters = FLAGS_lr_total_iters; | ||
| sched_config.linear_total_iters = FLAGS_lr_total_iters; | ||
| auto scheduler = CreateLRScheduler(optimizer, sched_config); | ||
|
|
||
| auto train_iter = train_loader.begin(); | ||
| std::shared_ptr<nn::Module> loss_fn | ||
| = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>()) | ||
|
|
@@ -330,6 +355,9 @@ void Train(const nn::parallel::Rank &rank) { | |
| } | ||
|
|
||
| optimizer->Step(); | ||
| if (scheduler) { | ||
| scheduler->Step(); | ||
| } | ||
| } else { | ||
| auto [x, y] = *train_iter; | ||
| // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below | ||
|
|
@@ -339,6 +367,9 @@ void Train(const nn::parallel::Rank &rank) { | |
| y = std::make_shared<Tensor>(y->To(device)); | ||
|
|
||
| lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); | ||
| if (scheduler) { | ||
| scheduler->Step(); | ||
| } | ||
| } | ||
|
|
||
| if (ddp_world_size > 1) { | ||
|
|
@@ -354,11 +385,11 @@ void Train(const nn::parallel::Rank &rank) { | |
| if (rank.IsLastRank()) { | ||
| size_t used_mb = 0, reserved_mb = 0; | ||
| std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); | ||
|
|
||
| const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 同理 |
||
| LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " | ||
| "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", | ||
| step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, | ||
| tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, | ||
| step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, | ||
| used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, | ||
| pp_world_size); | ||
|
|
||
| if ((step + 1) % FLAGS_freq_generate_txt == 0) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,186 @@ | ||
| #pragma once | ||
|
|
||
| #include <cmath> | ||
| #include <cstdint> | ||
| #include <functional> | ||
| #include <memory> | ||
| #include <string> | ||
| #include <unordered_map> | ||
| #include <variant> | ||
| #include <vector> | ||
|
|
||
| namespace infini_train { | ||
|
|
||
| class Optimizer; | ||
|
|
||
| using StateValue = std::variant<int64_t, float, double, std::string, std::vector<float>>; | ||
| using StateDict = std::unordered_map<std::string, StateValue>; | ||
|
|
||
| struct LRSchedulerConfig { | ||
| std::string type = "none"; | ||
| // ConstantLR | ||
| float constant_factor = 1.0f / 3.0f; | ||
| int constant_total_iters = 5; | ||
| // StepLR | ||
| int64_t step_size = 10; | ||
| float step_gamma = 0.1f; | ||
| // LinearLR | ||
| float linear_start_factor = 1.0f / 3.0f; | ||
| float linear_end_factor = 1.0f; | ||
| int linear_total_iters = 5; | ||
| // LambdaLR | ||
| std::function<float(int64_t)> lambda_fn = nullptr; | ||
| // SequentialLR | ||
| std::vector<LRSchedulerConfig> sequential_configs; | ||
| std::vector<int64_t> sequential_milestones; | ||
| // ChainedScheduler | ||
| std::vector<LRSchedulerConfig> chained_configs; | ||
| // warmup | ||
| int64_t warmup_steps = 0; | ||
| float warmup_start_factor = 1.0f / 3.0f; | ||
| float warmup_end_factor = 1.0f; | ||
| }; | ||
|
|
||
| class LRScheduler { | ||
| public: | ||
| template <typename T, typename... Args> static std::shared_ptr<T> Create(Args &&...args) { | ||
| auto scheduler = std::make_shared<T>(std::forward<Args>(args)...); | ||
| scheduler->InitialStep(); | ||
| return scheduler; | ||
| } | ||
|
|
||
| explicit LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step = -1); | ||
| virtual ~LRScheduler() = default; | ||
|
|
||
| LRScheduler(const LRScheduler &) = delete; | ||
| LRScheduler &operator=(const LRScheduler &) = delete; | ||
|
|
||
| virtual void Step(); | ||
| virtual void Step(int64_t epoch); | ||
| virtual void InitialStep(); | ||
|
|
||
| float GetLR() const; | ||
| float BaseLR() const; | ||
| int64_t LastStep() const; | ||
|
|
||
| void ResetStep(int64_t step = -1); | ||
| virtual StateDict State() const; | ||
| virtual void LoadState(const StateDict &state); | ||
|
|
||
| protected: | ||
| virtual float GetClosedFormLR() const = 0; | ||
| virtual float GetChainedFormLR() const; | ||
| void ApplyLR(float lr); | ||
|
|
||
| std::shared_ptr<Optimizer> optimizer_; | ||
| int64_t last_step_; | ||
| float current_lr_; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. current_lr_ 似乎也有点冗余，语义上 current_lr_ 和 optimizer_->GetLearningRate() 的值在任何时候应等价，现在在你的设计里看到这二者存在各自分开存且混用的状态（读完发现目前的 current_lr_ 像是 optimizer_->GetLearningRate() 的一个副本）；目前的数值正确性上你处理的没问题，但是这种设计交给后人来扩展的时候很可能带来歧义。 建议针对“当前学习率”只保留唯一真状态来源，要么就全程由 optimizer_->GetLearningRate() 跟踪，lr scheduler 里面就不存 current lr 了；要么就由 lr scheduler 跟踪，每次计算完再 set 回 optimizer。个人认为前者较合适。 |
||
| float base_lr_; | ||
| bool is_initial_ = false; | ||
| }; | ||
|
|
||
| std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config); | ||
|
|
||
| namespace lr_schedulers { | ||
|
|
||
| class ConstantLR : public LRScheduler { | ||
| public: | ||
| ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor = 1.0f / 3.0f, int total_iters = 5, | ||
| int64_t last_step = -1); | ||
| ~ConstantLR() override = default; | ||
|
|
||
| protected: | ||
| float GetChainedFormLR() const override; | ||
| float GetClosedFormLR() const override; | ||
|
|
||
| private: | ||
| const float factor_; | ||
| const int64_t total_iters_; | ||
| }; | ||
|
|
||
| class StepLR : public LRScheduler { | ||
| public: | ||
| StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); | ||
| ~StepLR() override = default; | ||
|
|
||
| protected: | ||
| float GetChainedFormLR() const override; | ||
| float GetClosedFormLR() const override; | ||
|
|
||
| private: | ||
| const int64_t step_size_; | ||
| const float gamma_; | ||
| }; | ||
|
|
||
| class LinearLR : public LRScheduler { | ||
| public: | ||
| LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f, | ||
| int64_t total_iters = 5, int64_t last_step = -1); | ||
| ~LinearLR() override = default; | ||
|
|
||
| protected: | ||
| float GetChainedFormLR() const override; | ||
| float GetClosedFormLR() const override; | ||
|
|
||
| private: | ||
| const float start_factor_; | ||
| const float end_factor_; | ||
| const int64_t total_iters_; | ||
| }; | ||
|
|
||
| class LambdaLR : public LRScheduler { | ||
| public: | ||
| using LambdaFunc = std::function<float(int64_t)>; | ||
|
|
||
| LambdaLR(std::shared_ptr<Optimizer> optimizer, LambdaFunc lr_lambda, int64_t last_step = -1); | ||
| ~LambdaLR() override = default; | ||
|
|
||
| protected: | ||
| float GetClosedFormLR() const override; | ||
|
|
||
| private: | ||
| const LambdaFunc lr_lambda_; | ||
| }; | ||
|
|
||
| class SequentialLR : public LRScheduler { | ||
| public: | ||
| SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers, | ||
| std::vector<int64_t> milestones, int64_t last_step = -1); | ||
| ~SequentialLR() override = default; | ||
|
|
||
| void Step() override; | ||
| void InitialStep() override; | ||
|
|
||
| StateDict State() const override; | ||
| void LoadState(const StateDict &state) override; | ||
|
|
||
| protected: | ||
| float GetClosedFormLR() const override { return current_lr_; } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这块语义上不太对，我仔细看了下 torch 里面的实现，GetClosedFormLR 对标 torch 里提供的 get_closed_form_lr 的接口的话, 实际是想实现一个“给定 base_lr、last_step 以及其他超参，然后可以通过公式算出当前 lr 的 function”。这个虽然数值上确实等于你现在提供的 current_lr，但是逻辑上的代码不应该直接返回缓存的 current_lr_ 就完事，而是应该给一个计算公式。 另外，torch 里提供的 _get_closed_form_lr 的接口，最终实际上是用于 step(int epoch) 这个 function 的，如果对应的 LRScheduler 派生类实现了这个 _get_closed_form_lr，就代表其支持 closed form 的跳步语义，然后 step(epoch) 会直接由提供的 function 计算出 current lr。而 torch 里面的 SequentialLR 派生类没有实现这个 function。 考虑到你这边的 GetClosedFormLR 定义为虚函数，要求所有派生类必须实现，我建议是在这里加上一个 // FIXME 的注释说明一下这一点，目前暂时先以返回一个 current lr 来 hack 实现，而不是提供了 closed-form 计算方法。 |
||
| void UndoChildInitialSteps(); | ||
|
|
||
| private: | ||
| std::vector<std::shared_ptr<LRScheduler>> schedulers_; | ||
| std::vector<int64_t> milestones_; | ||
| }; | ||
|
|
||
| class ChainedScheduler : public LRScheduler { | ||
| public: | ||
| ChainedScheduler(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers, | ||
| int64_t last_step = -1); | ||
| ~ChainedScheduler() override = default; | ||
|
|
||
| void Step() override; | ||
| void InitialStep() override; | ||
|
|
||
| StateDict State() const override; | ||
| void LoadState(const StateDict &state) override; | ||
|
|
||
| protected: | ||
| float GetClosedFormLR() const override { return current_lr_; } | ||
|
|
||
| private: | ||
| std::vector<std::shared_ptr<LRScheduler>> schedulers_; | ||
| }; | ||
|
|
||
| } // namespace lr_schedulers | ||
| } // namespace infini_train | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. format 规范上，end of file 需要有一个 newline，后续也有几个文件存在这个问题 |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
scheduler 在前面已经 Step 过了,所以这里 GetLR() 语义上是”下一步要用到的 lr“;而我们这里想打印的是每一步实际用到的 lr,所以这里的逻辑需要修改下。llama3 部分的 main.cc 里同理。