当前位置: 首页 > news >正文

【深度学习】实验 — 动手实现 GPT【二】:注意力机制、注意力掩码、多头注意力机制

【深度学习】实验 — 动手实现 GPT【二】:注意力机制、多头注意力机制

  • 注意力机制
    • 简单示例:单个元素的情况
    • 简单示例:计算所有输入词元的注意力权重
      • 推广到所有输入序列词元:
  • 注意力掩码
  • 代码实现多头注意力
  • 测试

注意力机制

简单示例:单个元素的情况

  • 假设我们有以下输入句子,已按照第 3 章中的描述嵌入为 3 维向量(此处使用非常小的嵌入维度,仅用于说明,方便在页面上显示而不换行):
inputs = torch.tensor([[0.43, 0.15, 0.89], # Your     (x^1)[0.55, 0.87, 0.66], # journey  (x^2)[0.57, 0.85, 0.64], # starts   (x^3)[0.22, 0.58, 0.33], # with     (x^4)[0.77, 0.25, 0.10], # one      (x^5)[0.05, 0.80, 0.55]] # step     (x^6)
)
  • (在本书中,我们遵循机器学习和深度学习的常见惯例,即训练样本表示为行,特征值表示为列;在上面的张量中,每一行表示一个词,每一列表示一个嵌入维度。)

  • 本节的主要目的是演示如何使用第二个输入序列 x ( 2 ) x^{(2)} x(2) 作为查询,计算上下文向量 z ( 2 ) z^{(2)} z(2)

  • 图示展示了该过程的初始步骤,其中通过点积操作计算 x ( 2 ) x^{(2)} x(2) 与所有其他输入元素之间的注意力分数 ω。

请添加图片描述

  • 我们使用输入序列中的元素 2,即 x ( 2 ) x^{(2)} x(2),作为示例来计算上下文向量 z ( 2 ) z^{(2)} z(2);在本节稍后,我们将推广此方法来计算所有的上下文向量。
  • 第一步是通过计算查询 x ( 2 ) x^{(2)} x(2) 与所有其他输入词元之间的点积,得到未归一化的注意力分数:
query = inputs[1]  # 2nd input token is the queryattn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):attn_scores_2[i] = torch.dot(x_i, query) # dot product (transpose not necessary here since they are 1-dim vectors)print(attn_scores_2)

输出

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
  • 步骤 2: 将未归一化的注意力分数(“omegas”, ω \omega ω)归一化,使其总和为 1。
  • 以下是一种简单的归一化方法,使未归一化的注意力分数总和为 1(这种方式是约定俗成的,有助于解释,并对训练稳定性非常重要):

请添加图片描述

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

输出

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)
  • 然而,在实际操作中,通常推荐使用 softmax 函数进行归一化,因为它在处理极端值方面更有效,并且在训练过程中具有更理想的梯度特性。
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

输出

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
  • 步骤 3:通过将嵌入的输入词元 x ( i ) x^{(i)} x(i) 与注意力权重相乘,并将所得向量求和,计算上下文向量 z ( 2 ) z^{(2)} z(2)请添加图片描述
query = inputs[1] # 2nd input token is the querycontext_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):context_vec_2 += attn_weights_2[i]*x_iprint(context_vec_2)

输出

tensor([0.4419, 0.6515, 0.5683])

简单示例:计算所有输入词元的注意力权重

推广到所有输入序列词元:

  • 上面我们计算了输入 2 的注意力权重和上下文向量。

  • 接下来,我们将推广该计算,以求得所有的注意力权重和上下文向量。
    请添加图片描述

  • (请注意,此图中的数字已截取至小数点后两位,以减少视觉杂乱;每行的值应相加为 1.0 或 100%;同样,其他图中的数字也被截取。)

  • 在自注意力机制中,首先计算注意力分数,随后对其进行归一化以得出总和为 1 的注意力权重。

  • 然后,这些注意力权重被用于通过输入的加权求和生成上下文向量。

请添加图片描述

  • 将之前的步骤 1应用于所有成对元素,以计算未归一化的注意力分数矩阵:
attn_scores = torch.empty(6, 6)for i, x_i in enumerate(inputs):for j, x_j in enumerate(inputs):attn_scores[i, j] = torch.dot(x_i, x_j)print(attn_scores)

输出

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],[0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],[0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],[0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],[0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
  • 我们可以通过矩阵乘法更高效地实现上述计算:
attn_scores = inputs @ inputs.T
print(attn_scores)

输出

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],[0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],[0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],[0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],[0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
  • 与之前的步骤 2类似,我们对每一行进行归一化,使每一行的值相加为 1:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

输出

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],[0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],[0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],[0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],[0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
  • 应用之前的步骤 3来计算所有上下文向量:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

输出

tensor([[0.4421, 0.5931, 0.5790],[0.4419, 0.6515, 0.5683],[0.4431, 0.6496, 0.5671],[0.4304, 0.6298, 0.5510],[0.4671, 0.5910, 0.5266],[0.4177, 0.6503, 0.5645]])

注意力掩码

  • 模型在序列中某一位置的预测仅依赖于之前位置的已知输出,而不依赖未来位置的输出。
  • 简单来说,这确保了每个下一个词的预测仅依赖于前面的词。
  • 为了实现这一点,对于每个给定词元,我们将未来的词元(即在当前词元之后的词元)进行掩码处理:
    请添加图片描述
attn_weights

输出

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],[0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],[0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],[0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],[0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
  • 最简单的方式是通过 PyTorch 的 tril 函数创建一个掩码,将主对角线下方的元素(包括主对角线)设置为 1,主对角线上方的元素设置为 0,以掩盖未来的注意力权重:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
tensor([[1., 0., 0., 0., 0., 0.],[1., 1., 0., 0., 0., 0.],[1., 1., 1., 0., 0., 0.],[1., 1., 1., 1., 0., 0.],[1., 1., 1., 1., 1., 0.],[1., 1., 1., 1., 1., 1.]])
  • 然后,我们可以将注意力权重与此掩码相乘,以将对角线上方的注意力分数置为零:
masked_simple = attn_weights*mask_simple
print(masked_simple)
tensor([[0.2098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],[0.1385, 0.2379, 0.0000, 0.0000, 0.0000, 0.0000],[0.1390, 0.2369, 0.2326, 0.0000, 0.0000, 0.0000],[0.1435, 0.2074, 0.2046, 0.1462, 0.0000, 0.0000],[0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.0000],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
  • 然而,如果在 softmax 之后应用掩码(如上所述),会破坏 softmax 创建的概率分布。
  • Softmax 确保所有输出值的总和为 1。
  • 在 softmax 之后进行掩码处理则需要重新归一化输出以再次使其总和为 1,这会使过程复杂化,并可能导致意想不到的效果。
  • 为确保每行的总和为 1,我们可以按如下方式归一化注意力权重:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
  • 让我们简单了解一种更高效的方法来实现上述目标。
  • 因此,与其将对角线上方的注意力权重置零并重新归一化结果,我们可以在未归一化的注意力分数进入 softmax 函数之前,将对角线上方的分数掩码为负无穷大。请添加图片描述
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

输出

tensor([[0.9995,   -inf,   -inf,   -inf,   -inf,   -inf],[0.9544, 1.4950,   -inf,   -inf,   -inf,   -inf],[0.9422, 1.4754, 1.4570,   -inf,   -inf,   -inf],[0.4753, 0.8434, 0.8296, 0.4937,   -inf,   -inf],[0.4576, 0.7070, 0.7154, 0.3474, 0.6654,   -inf],[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
  • 如下所示,现在每行的注意力权重再次正确地总和为 1:
attn_weights = torch.softmax(masked, dim=-1)
print(attn_weights)

输出

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],[0.3680, 0.6320, 0.0000, 0.0000, 0.0000, 0.0000],[0.2284, 0.3893, 0.3822, 0.0000, 0.0000, 0.0000],[0.2046, 0.2956, 0.2915, 0.2084, 0.0000, 0.0000],[0.1753, 0.2250, 0.2269, 0.1570, 0.2158, 0.0000],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

代码实现多头注意力

class MultiHeadAttention(nn.Module):def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):super().__init__()assert (d_out % num_heads == 0), \"d_out must be divisible by num_heads"self.d_out = d_outself.num_heads = num_headsself.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dimself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputsself.dropout = nn.Dropout(dropout)self.register_buffer("mask",torch.triu(torch.ones(context_length, context_length),diagonal=1))def forward(self, x):b, num_tokens, d_in = x.shapekeys = self.W_key(x) # Shape: (b, num_tokens, d_out)queries = self.W_query(x)values = self.W_value(x)# We implicitly split the matrix by adding a `num_heads` dimension# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)values = values.view(b, num_tokens, self.num_heads, self.head_dim)queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)keys = keys.transpose(1, 2)queries = queries.transpose(1, 2)values = values.transpose(1, 2)# Compute scaled dot-product attention (aka self-attention) with a causal maskattn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head# Original mask truncated to the number of tokens and converted to booleanmask_bool = self.mask.bool()[:num_tokens, :num_tokens]# Use the mask to fill attention scoresattn_scores.masked_fill_(mask_bool, -torch.inf)attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights)# Shape: (b, num_tokens, num_heads, head_dim)context_vec = (attn_weights @ values).transpose(1, 2)# Combine heads, where self.d_out = self.num_heads * self.head_dimcontext_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)context_vec = self.out_proj(context_vec) # optional projectionreturn context_vec

测试

batch = torch.stack((inputs, inputs), dim=0)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)context_vecs = mha(batch)print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

输出

tensor([[[-0.6033, -0.2785],[-0.5409, -0.2509],[-0.5241, -0.2439],[-0.4974, -0.2357],[-0.5224, -0.2520],[-0.4887, -0.2361]],[[-0.6033, -0.2785],[-0.5409, -0.2509],[-0.5241, -0.2439],[-0.4974, -0.2357],[-0.5224, -0.2520],[-0.4887, -0.2361]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
  • 另外请注意,我们在上面的 MultiHeadAttention 类中添加了一个线性投影层 (self.out_proj)。这只是一个不会改变维度的线性变换。在大型语言模型的实现中,使用这样的投影层是一个标准惯例,但并非绝对必要(最近的研究表明,移除该层不会影响模型性能);

http://www.mrgr.cn/news/62612.html

相关文章:

  • stm32工程建立流程(没有标准库,寄存器编写程序)
  • Cocos使用精灵组件显示相机内容
  • 【JavaEE】【多线程】进阶知识
  • 107. 阴影范围.shadow.camera
  • Scikit-learn和Keras简介
  • 隨筆 20241024 Kafka中的ISR列表:分区副本的族谱
  • ABAP RFC SQL 模糊查询和多个区间条件
  • 一些老程序员不愿透露的工作小技巧…
  • 【HDRP下实现视差效果_CubeMap和九宫格ArrayMap形式】
  • 2024年“炫转青春”山东省飞盘联赛盛大开赛——临沭县青少年飞盘运动迅速升温
  • 隐私保护下的数据提取策略
  • USC H5S支持大华ICC平台对接
  • QT:QThread:重写run函数
  • python函数连续
  • ARM base instruction -- adc
  • 2181、合并零之间的节点
  • YOLOv4和Darknet实现坑洼检测
  • 如何成为一名优秀的程序员,进来看看
  • 网络安全不知道怎么学,看完这篇,中学生都能学会
  • iOS 再谈KVC、 KVO
  • 阿里CDN框架
  • 前端实现echarts折线图堆叠(多条折线)
  • Jupyter notebook 添加目录插件
  • 一致校验矩阵计算
  • kdd比赛方案
  • 基于Python的PostgreSQL数据库操作示例(三)