LoRA(Low-Rank Adaptation)的工作机制 - 低秩矩阵来微调全连接层
LoRA(Low-Rank Adaptation)的工作机制 - 低秩矩阵来微调全连接层
flyfish
LoRA 核心思想
LoRA 的目标是在微调大型预训练模型时,通过仅更新一小部分参数来提高效率和效果。这一过程的关键在于将权重更新 Δ W \Delta W ΔW 分解为低秩矩阵的形式,从而显著减少需要调整的参数数量。具体来说,原始的权重矩阵 W 0 W_0 W0 被保持不变,而权重更新 Δ W \Delta W ΔW 被分解为两个低秩矩阵 B B B 和 A A A 的乘积,即:
W = W 0 + Δ W = W 0 + B A W = W_0 + \Delta W = W_0 + BA W=W0+ΔW=W0+BA
其中:
- W 0 W_0 W0 是预训练模型的原始权重矩阵。
- B ∈ R d × r B \in \mathbb{R}^{d \times r} B∈Rd×r 和 A ∈ R r × k A \in \mathbb{R}^{r \times k} A∈Rr×k 是低秩矩阵, r r r 是这两个矩阵的秩,且 r < < min ( d , k ) r << \min(d, k) r<<min(d,k)。
训练过程
-
初始化:
- A A A 矩阵使用高斯分布进行随机初始化,这为模型引入了随机性,有助于探索更广泛的参数空间。
- B B B 矩阵初始化为零矩阵,确保在训练初期,模型的输出主要由原始预训练模型决定,然后逐渐加入低秩适配的影响。
-
前向传播:
- 输入 x x x 通过原始权重矩阵 W 0 W_0 W0 进行变换,得到 W 0 x W_0x W0x。
- 同时,输入 x x x 也通过低秩矩阵 A A A 和 B B B 进行变换,得到 B A x BAx BAx。
- 最终的输出是两者的加权和,即 W 0 x + α r B A x W_0x + \frac{\alpha}{r}BAx W0x+rαBAx,其中 α r \frac{\alpha}{r} rα 是一个缩放因子,用于平衡原始模型和低秩适配的贡献。
-
反向传播:
- 模型的损失函数通过对 A A A 和 B B B 的梯度进行更新来优化,而原始权重矩阵 W 0 W_0 W0 保持不变。
- 通过这种方式,模型可以在保持大部分预训练知识的同时,高效地学习特定任务的知识。
推理过程
在推理阶段,为了提高效率,通常会将训练好的低秩矩阵 A A A 和 B B B 合并到原始权重矩阵 W 0 W_0 W0 中,得到新的权重矩阵 W W W。具体步骤如下:
-
权重合并:
- 将训练好的 B B B 和 A A A 矩阵相乘,得到 B A BA BA。
- 将 B A BA BA 乘以缩放因子 α r \frac{\alpha}{r} rα,得到 α r B A \frac{\alpha}{r}BA rαBA。
- 将 α r B A \frac{\alpha}{r}BA rαBA 加到原始权重矩阵 W 0 W_0 W0 上,得到新的权重矩阵 W = W 0 + α r B A W = W_0 + \frac{\alpha}{r}BA W=W0+rαBA。
-
推理:
- 使用新的权重矩阵 W W W 对输入 x x x 进行推理,得到最终的输出。
通过仅更新低秩矩阵 A A A 和 B B B,大大减少了需要调整的参数数量,从而降低了训练时间和内存需求。 B B B 矩阵初始化为零矩阵,确保训练初期模型输出主要由预训练模型决定,逐步引入低秩适配的影响,有助于训练的稳定性和收敛性。通过调整缩放因子 α r \frac{\alpha}{r} rα,可以灵活控制低秩适配对最终输出的影响,平衡原始模型和新任务之间的关系。
举例子 LoRA 应用于全连接层
MergedLinear
类通过引入低秩矩阵来微调全连接层。
MergedLinear 类通过在全连接层中引入 LoRA 技术,实现了在不显著增加模型参数的情况下提升模型性能的目标。
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
import mathclass LoRALayer:def __init__(self, r: int, lora_alpha: int, lora_dropout: float,merge_weights: bool,):"""初始化 LoRA 层。:param r: 低秩矩阵的秩:param lora_alpha: LoRA 的缩放因子:param lora_dropout: LoRA 的 dropout 概率:param merge_weights: 是否在训练模式下合并权重"""self.r = rself.lora_alpha = lora_alpha# 可选的 dropout 层if lora_dropout > 0.:self.lora_dropout = nn.Dropout(p=lora_dropout)else:self.lora_dropout = lambda x: x # 如果没有 dropout,则返回输入本身# 标记权重是否已合并self.merged = Falseself.merge_weights = merge_weightsclass MergedLinear(nn.Linear, LoRALayer):# LoRA 实现在一个全连接层中def __init__(self, in_features: int, out_features: int, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.,enable_lora: List[bool] = [False],fan_in_fan_out: bool = False,merge_weights: bool = True,**kwargs):"""初始化 MergedLinear 层。:param in_features: 输入特征数:param out_features: 输出特征数:param r: 低秩矩阵的秩:param lora_alpha: LoRA 的缩放因子:param lora_dropout: LoRA 的 dropout 概率:param enable_lora: 一个布尔列表,指示哪些部分启用 LoRA:param fan_in_fan_out: 是否交换权重矩阵的维度:param merge_weights: 是否在训练模式下合并权重:param kwargs: 其他传递给 nn.Linear 的参数"""nn.Linear.__init__(self, in_features, out_features, **kwargs) # 初始化 nn.LinearLoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) # 初始化 LoRALayerassert out_features % len(enable_lora) == 0, \'The length of enable_lora must divide out_features' # 确保 enable_lora 的长度能整除 out_featuresself.enable_lora = enable_loraself.fan_in_fan_out = fan_in_fan_out# 实际可训练的参数if r > 0 and any(enable_lora):self.lora_A = nn.Parameter(self.weight.new_zeros((r * sum(enable_lora), in_features))) # 低秩矩阵 Aself.lora_B = nn.Parameter(self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))) # 低秩矩阵 Bself.scaling = self.lora_alpha / self.r # 缩放因子# 冻结预训练的权重矩阵self.weight.requires_grad = False# 计算索引self.lora_ind = self.weight.new_zeros((out_features, ), dtype=torch.bool).view(len(enable_lora), -1)self.lora_ind[enable_lora, :] = Trueself.lora_ind = self.lora_ind.view(-1)self.reset_parameters() # 重置参数if fan_in_fan_out:self.weight.data = self.weight.data.transpose(0, 1) # 交换权重矩阵的维度def reset_parameters(self):"""重置参数。"""nn.Linear.reset_parameters(self) # 重置 nn.Linear 的参数if hasattr(self, 'lora_A'):# 以默认方式初始化 A,B 初始化为零nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))nn.init.zeros_(self.lora_B)def zero_pad(self, x):"""对 x 进行零填充。:param x: 输入张量:return: 零填充后的张量"""result = x.new_zeros((len(self.lora_ind), *x.shape[1:])) # 创建零填充的结果张量result[self.lora_ind] = x # 将 x 的值填充到结果张量的指定位置return resultdef merge_AB(self):"""计算低秩矩阵 A 和 B 的乘积。:return: 低秩矩阵 A 和 B 的乘积"""def T(w):return w.transpose(0, 1) if self.fan_in_fan_out else w # 交换权重矩阵的维度delta_w = F.conv1d(self.lora_A.unsqueeze(0), # 扩展维度以适应卷积操作self.lora_B.unsqueeze(-1), # 扩展维度以适应卷积操作groups=sum(self.enable_lora) # 分组卷积).squeeze(0) # 移除多余的维度return T(self.zero_pad(delta_w)) # 返回零填充后的结果def train(self, mode: bool = True):"""设置训练模式。:param mode: 是否为训练模式"""def T(w):return w.transpose(0, 1) if self.fan_in_fan_out else w # 交换权重矩阵的维度nn.Linear.train(self, mode) # 设置 nn.Linear 的训练模式if mode:if self.merge_weights and self.merged:# 确保权重未合并if self.r > 0 and any(self.enable_lora):self.weight.data -= self.merge_AB() * self.scaling # 从权重中减去低秩矩阵的贡献self.merged = Falseelse:if self.merge_weights and not self.merged:# 合并权重并标记if self.r > 0 and any(self.enable_lora):self.weight.data += self.merge_AB() * self.scaling # 将低秩矩阵的贡献加到权重上self.merged = True def forward(self, x: torch.Tensor):"""前向传播。:param x: 输入张量:return: 输出张量"""def T(w):return w.transpose(0, 1) if self.fan_in_fan_out else w # 交换权重矩阵的维度if self.merged:return F.linear(x, T(self.weight), bias=self.bias) # 如果已合并权重,直接进行线性变换else:result = F.linear(x, T(self.weight), bias=self.bias) # 进行线性变换if self.r > 0:result += self.lora_dropout(x) @ T(self.merge_AB().T) * self.scaling # 添加 LoRA 的贡献return result
说明
-
LoRALayer 类:
__init__
方法:- 初始化 LoRA 层的参数,包括低秩矩阵的秩
r
、缩放因子lora_alpha
、dropout 概率lora_dropout
和是否合并权重merge_weights
。 - 如果
lora_dropout
大于 0,则创建一个nn.Dropout
层;否则,创建一个恒等函数。 - 标记权重是否已合并
self.merged
,并设置self.merge_weights
。
- 初始化 LoRA 层的参数,包括低秩矩阵的秩
-
MergedLinear 类:
__init__
方法:- 初始化
nn.Linear
和LoRALayer
。 - 确保
enable_lora
的长度能整除out_features
,以确保每个部分的输出特征数相等。 - 如果启用了 LoRA,初始化低秩矩阵
lora_A
和lora_B
,并冻结预训练的权重矩阵。 - 计算索引
lora_ind
,用于确定哪些部分启用 LoRA。 - 重置参数,并根据
fan_in_fan_out
交换权重矩阵的维度。
- 初始化
reset_parameters
方法:- 重置
nn.Linear
的参数。 - 如果存在
lora_A
,初始化lora_A
和lora_B
。
- 重置
zero_pad
方法:- 对输入张量进行零填充,以便与原权重矩阵的形状匹配。
merge_AB
方法:- 计算低秩矩阵
lora_A
和lora_B
的乘积。 - 返回零填充后的结果。
- 计算低秩矩阵
train
方法:- 设置训练模式。
- 根据
merge_weights
和merged
标记,决定是否合并或分离权重。
forward
方法:- 前向传播。
- 根据是否已合并权重,进行线性变换,并添加 LoRA 的贡献。
MergedLinear
类的主要功能是在神经网络的全连接层中应用 Low-Rank Adaptation (LoRA) 技术。具体来说,它通过在预训练模型的权重矩阵上添加两个低秩矩阵(lora_A
和 lora_B
),来微调模型以适应特定任务。这两个低秩矩阵的乘积会被加到原始权重矩阵上。