LoRA(Low-Rank Adaptation)的工作机制 - 在 GPT-2 的注意力层中添加 LoRA 低秩适配器
LoRA(Low-Rank Adaptation)的工作机制 - 在 GPT-2 的注意力层中添加 LoRA 低秩适配器
flyfish
GPT-2 是一个语言模型,它通过多层的 Transformer 架构来生成和理解文本。Transformer 架构中最核心的部分是 注意力机制(Attention Mechanism)。
代码详解及注释
Attention
类的解释
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import loralib as loraclass Attention(nn.Module):def __init__(self, nx, n_ctx, config, scale=False):super(Attention, self).__init__()n_state = nx # 在 Attention 中,n_state 等于 nx(即 n_embd)# [从 Block 切换到 Attention,使 nx => n_state 保持与 TensorFlow 实现一致]assert n_state % config.n_head == 0 # 确保 n_state 可以被 n_head 整除self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))# 注册一个下三角矩阵作为注意力掩码,用于因果关系建模self.n_head = config.n_head # 注意力头的数量self.split_size = n_state # 每个头的维度self.scale = scale # 是否缩放注意力分数# 使用 LoRA 的 MergedLinear 层,将输入映射到 query, key, valueself.c_attn = lora.MergedLinear(nx, n_state * 3, r=config.lora_attn_dim, # LoRA 的秩lora_alpha=config.lora_attn_alpha, # LoRA 的缩放因子lora_dropout=config.lora_dropout, # LoRA 的 dropout 概率enable_lora=[True, False, True], # 仅对 query 和 value 应用 LoRAfan_in_fan_out=True, # 输入和输出的维度顺序merge_weights=False # 不在训练时合并权重)self.c_proj = Conv1D(n_state, nx) # 用于将注意力结果投影回原始维度self.config = config # 配置对象def _attn(self, q, k, v, len_kv=None):# 计算注意力分数w = torch.matmul(q, k)if self.scale:w = w / math.sqrt(v.size(-1)) # 缩放注意力分数nd, ns = w.size(-2), w.size(-1)b = self.bias[:, :, ns-nd:ns, :ns]w = w * b - 1e10 * (1 - b) # 应用因果掩码# q : (batch, head, q_seq_length, head_features)# k : (batch, head, head_features, kv_seq_length)# w : (batch, head, q_seq_length, kv_seq_length)# v : (batch, head, kv_seq_length, head_features)if len_kv is not None:_len = torch.arange(k.size(-1), device=k.device)_input_msk = _len[None, :] >= (len_kv)[:, None]w = w.masked_fill(_input_msk.unsqueeze(1).unsqueeze(2), -1.0e10) # 对超出长度的部分进行掩码w = nn.Softmax(dim=-1)(w) # 应用 Softmax 函数return torch.matmul(w, v) # 计算加权和def merge_heads(self, x):# 将多头注意力的结果合并回原始维度x = x.permute(0, 2, 1, 3).contiguous()new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)return x.view(*new_x_shape) # 形状变为 (batch, seq_length, n_state)def split_heads(self, x, k=False):# 将输入张量拆分为多个头new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)x = x.view(*new_x_shape) # 形状变为 (batch, seq_length, n_head, head_features)if k:return x.permute(0, 2, 3, 1).contiguous() # (batch, head, head_features, seq_length)else:return x.permute(0, 2, 1, 3).contiguous() # (batch, head, seq_length, head_features)def forward(self, x, history=None, layer_past=None, len_past=None):hidden_states = x # 输入隐藏状态x = self.c_attn(x) # 将输入映射到 query, key, valuequery, key, value = x.split(self.split_size, dim=2) # 拆分 query, key, valuequery = self.split_heads(query) # 将 query 拆分为多个头key = self.split_heads(key, k=True) # 将 key 拆分为多个头,并转置value = self.split_heads(value) # 将 value 拆分为多个头len_kv = None # 初始化 kv 序列长度if layer_past is not None:# 如果有过去的层状态(用于推理时的高效计算)if len_past is None:past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # 转置过去的 key 和 valuekey = torch.cat((past_key, key), dim=-1) # 拼接过去的 key 和当前的 keyvalue = torch.cat((past_value, value), dim=-2) # 拼接过去的 value 和当前的 valueelse:key_seq = key.shape[-1]assert key_seq == 1 # 确保 key 序列长度为 1_batch = torch.arange(0, key.shape[0], dtype=torch.long, device=key.device)past_key, past_value = layer_past[0], layer_past[1]past_key[_batch,:,len_past,:] = key.squeeze(-1) # 更新过去的 keypast_value[_batch,:,len_past,:] = value.squeeze(-2) # 更新过去的 valuekey = past_key.transpose(-2, -1) # 转置 keyvalue = past_valuelen_kv = len_past + 1 # 更新 kv 序列长度present = torch.stack((key.transpose(-2, -1), value)) # 将 key 和 value 堆叠在一起,形成当前层的状态a = self._attn(query, key, value, len_kv=len_kv) # 计算注意力a = self.merge_heads(a) # 合并多头注意力的结果a = self.c_proj(a) # 将注意力结果投影回原始维度return a, present # 返回注意力结果和当前层的状态
详细解释
-
初始化 (
__init__
方法):n_state
:等于nx
,表示每个注意力头的维度。bias
:注册一个下三角矩阵作为注意力掩码,用于因果关系建模。c_attn
:使用lora.MergedLinear
层将输入映射到query
、key
和value
。这里使用了 LoRA 来减少训练参数量。c_proj
:用于将注意力结果投影回原始维度的线性层。
-
注意力计算 (
_attn
方法):- 计算注意力分数
w
,即query
和key
的点积。 - 如果
scale
为True
,则对注意力分数进行缩放,防止梯度消失或爆炸。 - 应用因果掩码,确保每个位置只能看到之前的序列。
- 如果有
len_kv
,则对超出长度的部分进行掩码。 - 应用 Softmax 函数,将注意力分数转换为概率分布。
- 计算加权和,即
query
对value
的加权求和。
- 计算注意力分数
-
多头注意力的合并 (
merge_heads
方法):- 将多头注意力的结果从
(batch, head, seq_length, head_features)
合并为(batch, seq_length, n_state)
。
- 将多头注意力的结果从
-
多头注意力的拆分 (
split_heads
方法):- 将输入张量从
(batch, seq_length, n_state)
拆分为(batch, seq_length, n_head, head_features)
。 - 如果是
key
,则进行转置,使其形状为(batch, head, head_features, seq_length)
。
- 将输入张量从
-
前向传播 (
forward
方法):- 将输入映射到
query
、key
和value
。 - 将
query
、key
和value
拆分为多个头。 - 如果有过去的层状态,则拼接过去的
key
和value
,以提高推理效率。 - 计算注意力,并将结果合并回原始维度。
- 返回注意力结果和当前层的状态。
- 将输入映射到
理论
-
注意力机制:
- 注意力机制的核心思想是通过计算查询(query)和键(key)之间的相似度来确定值(value)的权重。
- 注意力分数 w w w 的计算公式为:
w = query ⋅ key T d k w = \frac{\text{query} \cdot \text{key}^T}{\sqrt{d_k}} w=dkquery⋅keyT - 其中 d k d_k dk 是
key
的维度,用于缩放注意力分数,防止梯度消失或爆炸。 - 应用 Softmax 函数将注意力分数转换为概率分布:
attention_weights = softmax ( w ) \text{attention\_weights} = \text{softmax}(w) attention_weights=softmax(w) - 最终的注意力结果 a a a 通过加权求和得到:
a = attention_weights ⋅ value a = \text{attention\_weights} \cdot \text{value} a=attention_weights⋅value
-
多头注意力:
- 多头注意力通过将输入拆分为多个头,分别计算注意力,然后将结果合并,从而捕捉不同子空间的信息。
- 每个头的计算公式为:
head i = Attention ( query i , key i , value i ) \text{head}_i = \text{Attention}(\text{query}_i, \text{key}_i, \text{value}_i) headi=Attention(queryi,keyi,valuei) - 最终的多头注意力结果通过将所有头的结果拼接并投影回原始维度:
multihead = Concat ( head 1 , head 2 , … , head h ) ⋅ W O \text{multihead} = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) \cdot W^O multihead=Concat(head1,head2,…,headh)⋅WO - 其中 h h h 是头的数量, W O W^O WO 是投影矩阵。
-
LoRA:
- LoRA 通过引入低秩矩阵 A A A 和 B B B 来适应特定任务,而原始权重 W W W 保持不变。
- 低秩矩阵的形式为:
W new = W + A ⋅ B W_{\text{new}} = W + A \cdot B Wnew=W+A⋅B - 通过这种方式,LoRA 可以显著减少训练参数量,同时保持模型的性能。