Skip to content

feat: Add LoRA (Low-Rank Adaptation) support for efficient model fine-tuning #108

Merged
kilinchange merged 4 commits into master from add_lora on Mar 16, 2026
Conversation


chen2021673 (Contributor) commented Feb 12, 2026

Summary

Added LoRA (Low-Rank Adaptation) support enabling parameter-efficient fine-tuning of large models through low-rank decomposition, significantly reducing the number of trainable parameters.

Key Features

  • LoRA Infrastructure: Configurable rank, alpha, and dropout parameters
  • Seamless Integration: LoRALinear wrapper for easy integration with existing Linear layers
  • Tensor Parallelism: LoRAParallelLinear supports TP/SP distributed training
  • DDP Integration: Properly handles LoRA parameters in distributed data parallel
  • Core APIs:
    • GetLoRAParameters() - Retrieve LoRA trainable parameters for optimizer
    • MergeAndUnload() - Merge LoRA weights into base model

New Files

  • infini_train/include/nn/lora/ - Headers (lora_config, lora_linear, lora_model, lora_parallel_linear, lora_utils)
  • infini_train/src/nn/lora/ - Implementations
  • test/lora/test_lora.cc - Unit tests
  • docs/lora_usage.md - Usage documentation

Examples

  • example/gpt2/main.cc - GPT2 LoRA fine-tuning example
  • example/llama3/main.cc - LLaMA3 LoRA fine-tuning example

Test Results

llama3 run: accuracy/performance compared against historical data:
[screenshot]

llama3 run: accuracy/performance compared against PyTorch:
[screenshot]
[screenshot]

llama3 multi-node run results:
Without LoRA:
[screenshot]
With LoRA:
[screenshot]

```cpp
          parallel::global::GetTensorParallelSize(),
          base_module->bias(), base_module->gather_output(), base_module->input_is_parallel(),
          base_module->skip_bias_add(), base_module->sequence_parallel()),
      config_(config) {
```
Contributor:
Noting this down first; it may still need discussion.

After reading this inheritance-based version I suddenly saw what is awkward about the approach. With polymorphism, say base class A and subclass B, in most actual usage you would construct B directly. Here, however, A and B are not constructed at the same time: the program first constructs A, and only later constructs B; and when constructing B, the constructor receives an A instance and uses its parameters to call A's constructor again to build the parent-class subobject.

The downsides of this:

  1. The code ends up more tangled, and it doesn't save much space;
  2. Polymorphism is normally used for the scenario where an A can be a B and a B is an A; with the current approach two RowParallelLinear instances exist at the same time (base_module and the subobject produced by inheritance), and the original one is simply discarded, so behaviorally it is more like B replacing A.

Since B replaces A, the LoRA implementation really plays more of a decorator role, and is a better fit for composition than inheritance.
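
The two shapes under discussion can be sketched side by side. This is a hedged illustration with invented minimal types, not the repository's classes: the inheritance variant rebuilds the parent from an existing instance's state and then discards it, while the composition variant keeps the original module alive and wraps it.

```cpp
// Hedged sketch of the two designs being compared; all type names are
// invented stand-ins, not the repository's actual classes.
#include <cassert>
#include <memory>
#include <utility>

struct Linear {               // stand-in for RowParallelLinear etc.
    int in_features;
    int out_features;
    Linear(int in, int out) : in_features(in), out_features(out) {}
    virtual ~Linear() = default;
};

// Variant A: inheritance from an *instance*. The passed-in base_module is
// only read for its parameters; a second Linear subobject is constructed
// inside the subclass and the original instance is discarded afterwards.
struct LoRALinearInherit : Linear {
    int rank;
    LoRALinearInherit(const Linear &base_module, int r)
        : Linear(base_module.in_features, base_module.out_features), rank(r) {}
};

// Variant B: composition (decorator). The original module stays alive and is
// reused; the wrapper only adds the low-rank branch around it.
struct LoRALinearWrap {
    std::shared_ptr<Linear> base;
    int rank;
    LoRALinearWrap(std::shared_ptr<Linear> base_module, int r)
        : base(std::move(base_module)), rank(r) {}
};
```

In variant A two Linear states exist during construction (the argument and the new subobject); in variant B there is exactly one, which is the "B replaces A" behavior described above made explicit.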

Collaborator:

I personally still lean toward implementing LoRALinear via inheritance. Semantically, LoRALinear is in essence still an implementation of Linear (y = xW + xBA), which is closer to an is-a relationship than a has-a one. Also, from a checkpoint-semantics point of view, LoRA reads more like extra parameters attached to a Linear than a new module type. (https://huggingface.co/docs/peft/en/developer_guides/checkpoint)

[screenshot]

These LoRA matrices are implemented as nn.Linear layers, so the parameters are stored in the .weight attribute (lora_A.weight, lora_B.weight).

As for the constructor-chain problem raised in the comment, I understand it mainly comes from the current implementation re-constructing the parent-class object from base_module's parameters. If LoRALinear instead constructed the parent part from its own constructor parameters, rather than receiving an existing Linear and copying its state, the problem should be avoidable (though I am not insisting on this change).

Also, regarding the issue that "calling A's constructor again with its parameters to build the parent object" leaves the subclass unable to access necessary internal state of the passed-in base instance, one option is to declare the subclass a friend of the base class, so the subclass can access the needed internal state directly.
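
The friend-based option can be sketched as follows. This is a hedged illustration with invented minimal types, not the repository's classes: a friend declaration in the base class lets the subclass's constructor read the private state of the passed-in base instance directly, instead of going through a pile of public getters.

```cpp
// Hedged sketch of the friend suggestion; type names are invented stand-ins.
#include <cassert>

struct LoRALinearF; // forward declaration so the base can befriend it

struct LinearF {
    friend struct LoRALinearF; // grants the subclass access to private state
    LinearF(int in, int out) : in_features_(in), out_features_(out) {}
private:
    int in_features_;
    int out_features_;
};

struct LoRALinearF : LinearF {
    int rank;
    // Friendship lets us read base.in_features_ / base.out_features_ here,
    // even though they are private and base is a *different* LinearF object
    // (derived classes alone cannot access private base members).
    LoRALinearF(const LinearF &base, int r)
        : LinearF(base.in_features_, base.out_features_), rank(r) {}
    // Friendship also covers our own inherited private members.
    int in_features() const { return in_features_; }
};
```

The trade-off the next comment raises is visible here: the subclass only works because the base class was edited to trust it, which couples the two more tightly than a normal derived class.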

Contributor (author):

Semantically, LoRALinear can be seen as an extension of Linear, so inheritance makes sense. But beyond semantics, the reasonableness of the construction pattern also has to be considered.
Currently we first have a base module, then use its parameters to re-construct the parent-class part inside the subclass, because the subclass is constructed later to replace the existing module and cannot get the original constructor arguments directly; that is also why I think it looks more like a decorator. If, in order to complete its construction, the subclass must access large amounts of the base class's private internal state, even resorting to friend, then it has already lost the simplicity and natural semantics of a subclass, and its relationship to the original parent class becomes blurry; at that point it would be better to declare it as a separate class.

Collaborator:

Semantically, implementing LoRALinear via inheritance is reasonable, but I also concede that, given the injection approach, the current inheritance-based implementation has redundancy in how it constructs the object. As for how LoRALinear gets injected, my suggestion is:

  • First, the design is already finalized; unless there is a serious problem, we don't recommend changing it now;
  • Second, existing implementations in the field have in fact chosen this approach, and in the usual case we prefer to stay compatible with their design.

So, without changing the injection design, the inheritance and composition implementations each have their own strengths and weaknesses. I won't keep pushing a mandatory choice; as long as the module base class is not intruded on, the decision is yours. Whether the injection design itself needs changing can be rediscussed when a concrete need arises.

Contributor:

Actually both approaches make sense; intuitively, inheritance feels like the natural choice, and what surfaces now is mainly a code-shape problem, essentially concentrated in the constructor, which involves a pile of getter calls.

I also looked at the definition in peft that is equivalent to our LoRALayer; the decorator behavior there is plain to see. Everything to be added is written in the LoraLayer class and passed to Linear, and that Linear inherits directly from nn.Module, with no direct inheritance relationship to nn.Linear.
[screenshot]
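
The peft-style structure described above can be transplanted to C++ as a rough sketch. This is a hedged illustration with invented minimal types, not peft's or the repository's code: the LoRA-specific state lives in a mixin-like LoRALayer class, and the LoRA linear derives from the Module base (not from the plain Linear), holding the wrapped layer by composition.

```cpp
// Hedged sketch of a peft-style layout; all type names are invented
// stand-ins, not peft's or the repository's actual classes.
#include <cassert>
#include <memory>
#include <utility>

struct Module {                 // stand-in for the framework's module base
    virtual ~Module() = default;
};

struct PlainLinear : Module {   // stand-in for nn.Linear
    int in_features;
    int out_features;
    PlainLinear(int in, int out) : in_features(in), out_features(out) {}
};

// Everything LoRA adds (rank, scaling, the wrapped base layer, ...) is
// gathered in one place, mirroring the role peft's LoraLayer plays.
struct LoRALayer {
    std::shared_ptr<PlainLinear> base_layer;
    int rank;
    float alpha;
    LoRALayer(std::shared_ptr<PlainLinear> base, int r, float a)
        : base_layer(std::move(base)), rank(r), alpha(a) {}
};

// The LoRA linear derives from Module directly and mixes in LoRALayer; it
// has no inheritance relationship with PlainLinear at all, so no parent-class
// state ever has to be reconstructed from an existing instance.
struct LoRALinearPeftStyle : Module, LoRALayer {
    LoRALinearPeftStyle(std::shared_ptr<PlainLinear> base, int r, float a)
        : LoRALayer(std::move(base), r, a) {}
};
```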

Contributor:

If we can't settle on a better solution for now, let's merge it as is, at least to cut down on duplicated effort. Once the transformer-model stack has been built and merged, we can take another look at the whole picture.

Contributor (author):

Merging with the inheritance-based version for now.

chen2021673 and others added 3 commits March 16, 2026 05:53
- Add LoRA module infrastructure with configurable rank, alpha, dropout
- Implement LoRALinear wrapper for seamless integration with Linear layers
- Support tensor parallelism via LoRAParallelLinear
- Add LoRAModel utility for managing multiple LoRA layers
- Integrate LoRA configuration and utilities
- Add GPT2 example demonstrating LoRA fine-tuning
- Include comprehensive usage documentation and test suite

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Refactor LoRA config construction with proper target module parsing
- Add GetLoRAModel for in-place LoRA layer injection
- Fix DDP reducer to correctly handle LoRA parameters
- Fix RowParallel/ColumnParallel LoRA input handling to match base module behavior
- Add shape-based defensive checks for TP/SP consistency
- Move TP/SP communication helper function declarations to utils.h
- Move getter implementations from header to .cc file
- Add unit test for SaveLoRAWeights/LoadLoRAWeights functionality

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
kilinchange (Collaborator) commented Mar 16, 2026

Please also post the multi-node test screenshots (the 3D-parallel case, with and without LoRA).

- Refactor GetLoRAParameters() to retrieve only LoRA parameters for optimizer
- Add MergeAndUnload() to merge weights and export as standard model
- Update gpt2/llama3 examples to use new GetLoRAParameters API
- Refactor LoRA linear modules and fix dimension mismatch
- Improve LoRA tests and update documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
kilinchange merged commit bdec219 into master on Mar 16, 2026
2 checks passed
kilinchange deleted the add_lora branch on March 16, 2026 09:38