当前位置: 首页 > news >正文

LoRA(Low-Rank Adaptation)的工作机制 - 在 GPT-2 的注意力层中添加 LoRA 低秩适配器

LoRA(Low-Rank Adaptation)的工作机制 - 在 GPT-2 的注意力层中添加 LoRA 低秩适配器

flyfish

GPT-2 是一个语言模型,它通过多层的 Transformer 架构来生成和理解文本。Transformer 架构中最核心的部分是 注意力机制(Attention Mechanism)。

代码详解及注释

Attention 类的解释

import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import loralib as loraclass Attention(nn.Module):def __init__(self, nx, n_ctx, config, scale=False):super(Attention, self).__init__()n_state = nx  # 在 Attention 中,n_state 等于 nx(即 n_embd)# [从 Block 切换到 Attention,使 nx => n_state 保持与 TensorFlow 实现一致]assert n_state % config.n_head == 0  # 确保 n_state 可以被 n_head 整除self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))# 注册一个下三角矩阵作为注意力掩码,用于因果关系建模self.n_head = config.n_head  # 注意力头的数量self.split_size = n_state  # 每个头的维度self.scale = scale  # 是否缩放注意力分数# 使用 LoRA 的 MergedLinear 层,将输入映射到 query, key, valueself.c_attn = lora.MergedLinear(nx, n_state * 3, r=config.lora_attn_dim,  # LoRA 的秩lora_alpha=config.lora_attn_alpha,  # LoRA 的缩放因子lora_dropout=config.lora_dropout,  # LoRA 的 dropout 概率enable_lora=[True, False, True],  # 仅对 query 和 value 应用 LoRAfan_in_fan_out=True,  # 输入和输出的维度顺序merge_weights=False  # 不在训练时合并权重)self.c_proj = Conv1D(n_state, nx)  # 用于将注意力结果投影回原始维度self.config = config  # 配置对象def _attn(self, q, k, v, len_kv=None):# 计算注意力分数w = torch.matmul(q, k)if self.scale:w = w / math.sqrt(v.size(-1))  # 缩放注意力分数nd, ns = w.size(-2), w.size(-1)b = self.bias[:, :, ns-nd:ns, :ns]w = w * b - 1e10 * (1 - b)  # 应用因果掩码# q : (batch, head, q_seq_length, head_features)# k : (batch, head, head_features, kv_seq_length)# w : (batch, head, q_seq_length, kv_seq_length)# v : (batch, head, kv_seq_length, head_features)if len_kv is not None:_len = torch.arange(k.size(-1), device=k.device)_input_msk =  _len[None, :] >= (len_kv)[:, None]w = w.masked_fill(_input_msk.unsqueeze(1).unsqueeze(2), -1.0e10)  # 对超出长度的部分进行掩码w = nn.Softmax(dim=-1)(w)  # 应用 Softmax 函数return torch.matmul(w, v)  # 计算加权和def merge_heads(self, x):# 将多头注意力的结果合并回原始维度x = x.permute(0, 2, 1, 3).contiguous()new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)return x.view(*new_x_shape)  # 形状变为 (batch, seq_length, n_state)def split_heads(self, x, k=False):# 将输入张量拆分为多个头new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)x = x.view(*new_x_shape)  # 形状变为 (batch, seq_length, n_head, head_features)if k:return x.permute(0, 2, 3, 1).contiguous()  # (batch, head, head_features, seq_length)else:return x.permute(0, 2, 1, 3).contiguous()  # (batch, head, seq_length, head_features)def forward(self, x, history=None, layer_past=None, len_past=None):hidden_states = x  # 输入隐藏状态x = self.c_attn(x)  # 将输入映射到 query, key, valuequery, key, value = x.split(self.split_size, dim=2)  # 拆分 query, key, valuequery = self.split_heads(query)  # 将 query 拆分为多个头key = self.split_heads(key, k=True)  # 将 key 拆分为多个头,并转置value = self.split_heads(value)  # 将 value 拆分为多个头len_kv = None  # 初始化 kv 序列长度if layer_past is not None:# 如果有过去的层状态(用于推理时的高效计算)if len_past is None:past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # 转置过去的 key 和 valuekey = torch.cat((past_key, key), dim=-1)  # 拼接过去的 key 和当前的 keyvalue = torch.cat((past_value, value), dim=-2)  # 拼接过去的 value 和当前的 valueelse:key_seq = key.shape[-1]assert key_seq == 1  # 确保 key 序列长度为 1_batch = torch.arange(0, key.shape[0], dtype=torch.long, device=key.device)past_key, past_value = layer_past[0], layer_past[1]past_key[_batch,:,len_past,:] = key.squeeze(-1)  # 更新过去的 keypast_value[_batch,:,len_past,:] = value.squeeze(-2)  # 更新过去的 valuekey = past_key.transpose(-2, -1)  # 转置 keyvalue = past_valuelen_kv = len_past + 1  # 更新 kv 序列长度present = torch.stack((key.transpose(-2, -1), value))  # 将 key 和 value 堆叠在一起,形成当前层的状态a = self._attn(query, key, value, len_kv=len_kv)  # 计算注意力a = self.merge_heads(a)  # 合并多头注意力的结果a = self.c_proj(a)  # 将注意力结果投影回原始维度return a, present  # 返回注意力结果和当前层的状态

详细解释

  1. 初始化 (__init__ 方法)

    • n_state:等于 nx,表示每个注意力头的维度。
    • bias:注册一个下三角矩阵作为注意力掩码,用于因果关系建模。
    • c_attn:使用 lora.MergedLinear 层将输入映射到 querykeyvalue。这里使用了 LoRA 来减少训练参数量。
    • c_proj:用于将注意力结果投影回原始维度的线性层。
  2. 注意力计算 (_attn 方法)

    • 计算注意力分数 w,即 querykey 的点积。
    • 如果 scaleTrue,则对注意力分数进行缩放,防止梯度消失或爆炸。
    • 应用因果掩码,确保每个位置只能看到之前的序列。
    • 如果有 len_kv,则对超出长度的部分进行掩码。
    • 应用 Softmax 函数,将注意力分数转换为概率分布。
    • 计算加权和,即 queryvalue 的加权求和。
  3. 多头注意力的合并 (merge_heads 方法)

    • 将多头注意力的结果从 (batch, head, seq_length, head_features) 合并为 (batch, seq_length, n_state)
  4. 多头注意力的拆分 (split_heads 方法)

    • 将输入张量从 (batch, seq_length, n_state) 拆分为 (batch, seq_length, n_head, head_features)
    • 如果是 key,则进行转置,使其形状为 (batch, head, head_features, seq_length)
  5. 前向传播 (forward 方法)

    • 将输入映射到 querykeyvalue
    • querykeyvalue 拆分为多个头。
    • 如果有过去的层状态,则拼接过去的 keyvalue,以提高推理效率。
    • 计算注意力,并将结果合并回原始维度。
    • 返回注意力结果和当前层的状态。

理论

  • 注意力机制

    • 注意力机制的核心思想是通过计算查询(query)和键(key)之间的相似度来确定值(value)的权重。
    • 注意力分数 w w w 的计算公式为:
      w = query ⋅ key T d k w = \frac{\text{query} \cdot \text{key}^T}{\sqrt{d_k}} w=dk querykeyT
    • 其中 d k d_k dkkey 的维度,用于缩放注意力分数,防止梯度消失或爆炸。
    • 应用 Softmax 函数将注意力分数转换为概率分布:
      attention_weights = softmax ( w ) \text{attention\_weights} = \text{softmax}(w) attention_weights=softmax(w)
    • 最终的注意力结果 a a a 通过加权求和得到:
      a = attention_weights ⋅ value a = \text{attention\_weights} \cdot \text{value} a=attention_weightsvalue
  • 多头注意力

    • 多头注意力通过将输入拆分为多个头,分别计算注意力,然后将结果合并,从而捕捉不同子空间的信息。
    • 每个头的计算公式为:
      head i = Attention ( query i , key i , value i ) \text{head}_i = \text{Attention}(\text{query}_i, \text{key}_i, \text{value}_i) headi=Attention(queryi,keyi,valuei)
    • 最终的多头注意力结果通过将所有头的结果拼接并投影回原始维度:
      multihead = Concat ( head 1 , head 2 , … , head h ) ⋅ W O \text{multihead} = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) \cdot W^O multihead=Concat(head1,head2,,headh)WO
    • 其中 h h h 是头的数量, W O W^O WO 是投影矩阵。
  • LoRA

    • LoRA 通过引入低秩矩阵 A A A B B B 来适应特定任务,而原始权重 W W W 保持不变。
    • 低秩矩阵的形式为:
      W new = W + A ⋅ B W_{\text{new}} = W + A \cdot B Wnew=W+AB
    • 通过这种方式,LoRA 可以显著减少训练参数量,同时保持模型的性能。

http://www.mrgr.cn/news/66270.html

相关文章:

  • 管易到金蝶销售数据集成全流程详解
  • LeetCode 热题100 之 回溯1
  • yaml文件编写
  • FFmpeg 4.3 音视频-多路H265监控录放C++开发十二:在屏幕上显示多路视频播放,可以有不同的分辨率,格式和帧率。
  • clickhouse运维篇(三):生产环境一键生成配置并快速部署ck集群
  • 基于SSM+微信小程序的社团登录管理系统(社团1)
  • Git遇到“fatal: bad object refs/heads/master - 副本”问题的解决办法
  • 基于 GADF+Swin-CNN-GAM 的高创新轴承故障诊断模型
  • 41.第二阶段x86游戏实战2-C++实现lua寻路
  • 基于STM32的自动化植物浇灌系统教学
  • 【Qt】使用Qt发送http请求封装一个通用类
  • 劫持微信聊天记录并分析还原 —— 解密数据库(二)
  • 工作中问题
  • 新一代跟踪器StrongSORT: Make DeepSORT Great Again论文解析—让 DeepSORT 再次伟大
  • nacos本地虚拟机搭建切换wiff问题
  • 基于SpringBoot的免税商品优选购物商城的设计与实现
  • 小美和大富翁
  • 动态规划 —— dp问题-按摩师
  • Docker 的基本概念和优势
  • 气体传感器种类详解:从半导体到红外吸收型的全面解析
  • 仿真APP助力汽车零部件厂商打造核心竞争力
  • 解决从huggingface.co下载模型失败问题
  • EasyQBlog .NET 8 + Q-Blog 2.0博客模板 + easyweb iframe后台模板 开发的个人博客
  • 树莓派开发相关知识十 -小车服务器
  • Python打包脚本为EXE可执行文件
  • 信息安全工程师(77)常见网络安全应急事件场景与处理流程