当前位置: 首页 > news >正文

仅仅使用pytorch来手撕transformer架构(2):多头注意力MultiHeadAttention类的实现和向前传播

手撕MultiHeadAttention 类的代码,结合具体的例子来说明每一步的作用和计算过程。

往期文章:
仅仅使用pytorch来手撕transformer架构(1):位置编码的类的实现和向前传播

最适合小白入门的Transformer介绍

1. 初始化方法 __init__

def __init__(self, embed_size, heads):super(MultiHeadAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

1.1参数解释

  • embed_size:嵌入向量的维度,表示每个输入向量的大小。
  • heads:注意力头的数量。多头注意力机制将输入分割成多个“头”,每个头学习不同的特征。
  • head_dim:每个注意力头的维度大小,计算公式为 embed_size // heads。这意味着每个头处理的特征子集的大小。

1.2线性变换层

  • self.valuesself.keysself.queries

    • 这些是线性变换层,用于将输入的嵌入向量分别转换为值(Values)、键(Keys)和查询(Queries)。
    • 每个线性层的输入和输出维度都是 self.head_dim,因为每个头处理的特征子集大小为 self.head_dim
    • 使用 bias=False 是为了简化计算,避免引入额外的偏置项。
  • self.fc_out

    • 在多头注意力计算完成后,将所有头的输出拼接起来,并通过一个线性层将维度转换回原始的嵌入维度 embed_size

2. 前向传播方法 forward

def forward(self, values, keys, query, mask):N = query.shape[0]  # Batch sizevalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

2.1输入参数

  • valueskeysquery
    • 这三个输入张量的形状通常为 (batch_size, seq_len, embed_size)
    • 它们分别对应于值(Values)、键(Keys)和查询(Queries)。
  • mask
    • 用于遮蔽某些位置的注意力权重,避免模型关注到不应该关注的部分(例如,解码器中的未来信息)。

2.2多头注意力计算过程

2.2.1 将输入嵌入分割为多个头:
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
  • 将输入的嵌入向量分割成 heads 个头,每个头的维度为 self.head_dim
  • 例如,如果 embed_size = 256heads = 8,则 self.head_dim = 32,每个头处理 32 维的特征。
  • 重塑后的形状为 (N, seq_len, heads, head_dim)
2.2.2 线性变换:
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
  • 对每个头的值、键和查询分别进行线性变换。
  • 这一步将输入特征投影到不同的子空间中,使得每个头可以学习不同的特征。
2.2.3计算注意力分数(Attention Scores):
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
  • 使用 torch.einsum 计算查询和键之间的点积,得到注意力分数矩阵。
  • 公式 nqhd,nkhd->nhqk 表示:
    • n:批量大小(Batch Size)。
    • q:查询序列的长度。
    • k:键序列的长度。
    • h:头的数量。
    • d:每个头的维度。
  • 输出的 energy 形状为 (N, heads, query_len, key_len)
2.2.4应用掩码(Masking):
if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))
  • 如果提供了掩码,将掩码为 0 的位置的注意力分数设置为一个非常小的值(如 -1e20),这样在后续的 softmax 计算中,这些位置的权重会趋近于 0。
2.2.5计算注意力权重:
attention = torch.softmax(energy / (self.embed_size ** (0.5)), dim=3)
  • 对注意力分数进行 softmax 归一化,得到注意力权重。
  • 除以 sqrt(embed_size) 是为了缩放点积结果,避免梯度消失或爆炸。
2.2.6应用注意力权重:
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim
)
  • 使用 torch.einsum 将注意力权重与值相乘,得到加权的值。
  • 公式 nhql,nlhd->nqhd 表示:
    • n:批量大小。
    • h:头的数量。
    • q:查询序列的长度。
    • l:值序列的长度。
    • d:每个头的维度。
  • 输出的 out 形状为 (N, query_len, heads * self.head_dim)
2.2.7线性变换输出:
out = self.fc_out(out)
  • 将所有头的输出拼接起来,并通过一个线性层将维度转换回原始的嵌入维度 embed_size

3. 示例矩阵计算

假设:

  • embed_size = 4
  • heads = 2
  • head_dim = embed_size // heads = 2
  • 输入序列长度为 3,批量大小为 1。

3.1输入张量

values = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]], dtype=torch.float32)
keys = torch.tensor([[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], dtype=torch.float32)
query = torch.tensor([[[25, 26, 27, 28], [29, 30, 31, 32], [33, 34, 35, 36]]], dtype=torch.float32)
mask = None

3.2重塑为多头

values = values.reshape(1, 3, 2, 2)  # (N, value_len, heads, head_dim)
keys = keys.reshape(1, 3, 2, 2)
queries = query.reshape(1, 3, 2, 2)

3.3线性变换

假设线性变换层的权重为单位矩阵(简化计算),则:

values = self.values(values)  # 不改变值
keys = self.keys(keys)
queries = self.queries(queries)

3.4计算注意力分数

energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

假设:

  • queries = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
  • keys = [[[13, 14], [15, 16]], [[17, 18], [19, 20]], [[21, 22], [23, 24]]]

计算点积:

energy = [[[[1*13 + 2*14, 1*15 + 2*16], [1*17 + 2*18, 1*19 + 2*20]],

完整代码:

class MultiHeadAttention(nn.Module):def __init__(self, embed_size, heads):super(MultiHeadAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert self.head_dim * heads == embed_size, "嵌入尺寸需要被头部整除"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split the embedding into self.heads different piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)out = self.fc_out(out)return out

作者码字不易,觉得有用的话不妨点个赞吧,关注我,持续为您更新AI的优质内容。


http://www.mrgr.cn/news/93966.html

相关文章:

  • NX二次开发,创建基准平面
  • 正则表达式(2)匹配规则
  • ①Modbus TCP转Modbus RTU/ASCII网关同步采集无需编程高速轻松组网
  • AI学习记录 - PPO算法草稿
  • LiveCommunicationKit OC 实现
  • 力扣热题 100:二叉树专题进阶题解析(后7道)
  • 23种设计模式简介
  • Liunx(CentOS-6-x86_64)使用Nginx部署Vue项目
  • VUE3开发-9、axios前后端跨域问题解决方案
  • 英语学习(GitHub学到的分享)
  • 滑动窗口算法-day7(越长越合法子数组)
  • 18、函数的反柯里化
  • SpringMVC 基本概念与代码示例
  • 【git】 贮藏 stash
  • 《 C++ 点滴漫谈: 三十 》高手写 C++,参数这样传才高效!你真的用对了吗?
  • 【git】删除已加入 .gitignore却仍被git追踪的文件
  • 1分钟看懂React的那些Hook‘s
  • java每日精进 3.11 【多租户】
  • 【性能测试】Jmeter详细操作-小白使用手册(2)
  • win10安装部署DB-gpt,坑多