【深度学习】实验 — 动手实现 GPT【二】:注意力机制、注意力掩码、多头注意力机制
【深度学习】实验 — 动手实现 GPT【二】:注意力机制、多头注意力机制
- 注意力机制
- 简单示例:单个元素的情况
- 简单示例:计算所有输入词元的注意力权重
- 推广到所有输入序列词元:
- 注意力掩码
- 代码实现多头注意力
- 测试
注意力机制
简单示例:单个元素的情况
- 假设我们有以下输入句子,已按照第 3 章中的描述嵌入为 3 维向量(此处使用非常小的嵌入维度,仅用于说明,方便在页面上显示而不换行):
inputs = torch.tensor([[0.43, 0.15, 0.89], # Your (x^1)[0.55, 0.87, 0.66], # journey (x^2)[0.57, 0.85, 0.64], # starts (x^3)[0.22, 0.58, 0.33], # with (x^4)[0.77, 0.25, 0.10], # one (x^5)[0.05, 0.80, 0.55]] # step (x^6)
)
-
(在本书中,我们遵循机器学习和深度学习的常见惯例,即训练样本表示为行,特征值表示为列;在上面的张量中,每一行表示一个词,每一列表示一个嵌入维度。)
-
本节的主要目的是演示如何使用第二个输入序列 x ( 2 ) x^{(2)} x(2) 作为查询,计算上下文向量 z ( 2 ) z^{(2)} z(2)。
-
图示展示了该过程的初始步骤,其中通过点积操作计算 x ( 2 ) x^{(2)} x(2) 与所有其他输入元素之间的注意力分数 ω。
- 我们使用输入序列中的元素 2,即 x ( 2 ) x^{(2)} x(2),作为示例来计算上下文向量 z ( 2 ) z^{(2)} z(2);在本节稍后,我们将推广此方法来计算所有的上下文向量。
- 第一步是通过计算查询 x ( 2 ) x^{(2)} x(2) 与所有其他输入词元之间的点积,得到未归一化的注意力分数:
query = inputs[1] # 2nd input token is the queryattn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):attn_scores_2[i] = torch.dot(x_i, query) # dot product (transpose not necessary here since they are 1-dim vectors)print(attn_scores_2)
输出
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
- 步骤 2: 将未归一化的注意力分数(“omegas”, ω \omega ω)归一化,使其总和为 1。
- 以下是一种简单的归一化方法,使未归一化的注意力分数总和为 1(这种方式是约定俗成的,有助于解释,并对训练稳定性非常重要):
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())
输出
Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)
- 然而,在实际操作中,通常推荐使用 softmax 函数进行归一化,因为它在处理极端值方面更有效,并且在训练过程中具有更理想的梯度特性。
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())
输出
Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
- 步骤 3:通过将嵌入的输入词元 x ( i ) x^{(i)} x(i) 与注意力权重相乘,并将所得向量求和,计算上下文向量 z ( 2 ) z^{(2)} z(2):
query = inputs[1] # 2nd input token is the querycontext_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):context_vec_2 += attn_weights_2[i]*x_iprint(context_vec_2)
输出
tensor([0.4419, 0.6515, 0.5683])
简单示例:计算所有输入词元的注意力权重
推广到所有输入序列词元:
-
上面我们计算了输入 2 的注意力权重和上下文向量。
-
接下来,我们将推广该计算,以求得所有的注意力权重和上下文向量。
-
(请注意,此图中的数字已截取至小数点后两位,以减少视觉杂乱;每行的值应相加为 1.0 或 100%;同样,其他图中的数字也被截取。)
-
在自注意力机制中,首先计算注意力分数,随后对其进行归一化以得出总和为 1 的注意力权重。
-
然后,这些注意力权重被用于通过输入的加权求和生成上下文向量。
- 将之前的步骤 1应用于所有成对元素,以计算未归一化的注意力分数矩阵:
attn_scores = torch.empty(6, 6)for i, x_i in enumerate(inputs):for j, x_j in enumerate(inputs):attn_scores[i, j] = torch.dot(x_i, x_j)print(attn_scores)
输出
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],[0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],[0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],[0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],[0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
- 我们可以通过矩阵乘法更高效地实现上述计算:
attn_scores = inputs @ inputs.T
print(attn_scores)
输出
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],[0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],[0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],[0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],[0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
- 与之前的步骤 2类似,我们对每一行进行归一化,使每一行的值相加为 1:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)
输出
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],[0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],[0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],[0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],[0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
- 应用之前的步骤 3来计算所有上下文向量:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)
输出
tensor([[0.4421, 0.5931, 0.5790],[0.4419, 0.6515, 0.5683],[0.4431, 0.6496, 0.5671],[0.4304, 0.6298, 0.5510],[0.4671, 0.5910, 0.5266],[0.4177, 0.6503, 0.5645]])
注意力掩码
- 模型在序列中某一位置的预测仅依赖于之前位置的已知输出,而不依赖未来位置的输出。
- 简单来说,这确保了每个下一个词的预测仅依赖于前面的词。
- 为了实现这一点,对于每个给定词元,我们将未来的词元(即在当前词元之后的词元)进行掩码处理:
attn_weights
输出
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],[0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],[0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],[0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],[0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
- 最简单的方式是通过 PyTorch 的
tril
函数创建一个掩码,将主对角线下方的元素(包括主对角线)设置为 1,主对角线上方的元素设置为 0,以掩盖未来的注意力权重:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
tensor([[1., 0., 0., 0., 0., 0.],[1., 1., 0., 0., 0., 0.],[1., 1., 1., 0., 0., 0.],[1., 1., 1., 1., 0., 0.],[1., 1., 1., 1., 1., 0.],[1., 1., 1., 1., 1., 1.]])
- 然后,我们可以将注意力权重与此掩码相乘,以将对角线上方的注意力分数置为零:
masked_simple = attn_weights*mask_simple
print(masked_simple)
tensor([[0.2098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],[0.1385, 0.2379, 0.0000, 0.0000, 0.0000, 0.0000],[0.1390, 0.2369, 0.2326, 0.0000, 0.0000, 0.0000],[0.1435, 0.2074, 0.2046, 0.1462, 0.0000, 0.0000],[0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.0000],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
- 然而,如果在 softmax 之后应用掩码(如上所述),会破坏 softmax 创建的概率分布。
- Softmax 确保所有输出值的总和为 1。
- 在 softmax 之后进行掩码处理则需要重新归一化输出以再次使其总和为 1,这会使过程复杂化,并可能导致意想不到的效果。
- 为确保每行的总和为 1,我们可以按如下方式归一化注意力权重:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
- 让我们简单了解一种更高效的方法来实现上述目标。
- 因此,与其将对角线上方的注意力权重置零并重新归一化结果,我们可以在未归一化的注意力分数进入 softmax 函数之前,将对角线上方的分数掩码为负无穷大。
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
输出
tensor([[0.9995, -inf, -inf, -inf, -inf, -inf],[0.9544, 1.4950, -inf, -inf, -inf, -inf],[0.9422, 1.4754, 1.4570, -inf, -inf, -inf],[0.4753, 0.8434, 0.8296, 0.4937, -inf, -inf],[0.4576, 0.7070, 0.7154, 0.3474, 0.6654, -inf],[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
- 如下所示,现在每行的注意力权重再次正确地总和为 1:
attn_weights = torch.softmax(masked, dim=-1)
print(attn_weights)
输出
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],[0.3680, 0.6320, 0.0000, 0.0000, 0.0000, 0.0000],[0.2284, 0.3893, 0.3822, 0.0000, 0.0000, 0.0000],[0.2046, 0.2956, 0.2915, 0.2084, 0.0000, 0.0000],[0.1753, 0.2250, 0.2269, 0.1570, 0.2158, 0.0000],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
代码实现多头注意力
class MultiHeadAttention(nn.Module):def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):super().__init__()assert (d_out % num_heads == 0), \"d_out must be divisible by num_heads"self.d_out = d_outself.num_heads = num_headsself.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dimself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputsself.dropout = nn.Dropout(dropout)self.register_buffer("mask",torch.triu(torch.ones(context_length, context_length),diagonal=1))def forward(self, x):b, num_tokens, d_in = x.shapekeys = self.W_key(x) # Shape: (b, num_tokens, d_out)queries = self.W_query(x)values = self.W_value(x)# We implicitly split the matrix by adding a `num_heads` dimension# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)values = values.view(b, num_tokens, self.num_heads, self.head_dim)queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)keys = keys.transpose(1, 2)queries = queries.transpose(1, 2)values = values.transpose(1, 2)# Compute scaled dot-product attention (aka self-attention) with a causal maskattn_scores = queries @ keys.transpose(2, 3) # Dot product for each head# Original mask truncated to the number of tokens and converted to booleanmask_bool = self.mask.bool()[:num_tokens, :num_tokens]# Use the mask to fill attention scoresattn_scores.masked_fill_(mask_bool, -torch.inf)attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights)# Shape: (b, num_tokens, num_heads, head_dim)context_vec = (attn_weights @ values).transpose(1, 2)# Combine heads, where self.d_out = self.num_heads * self.head_dimcontext_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)context_vec = self.out_proj(context_vec) # optional projectionreturn context_vec
测试
batch = torch.stack((inputs, inputs), dim=0)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)context_vecs = mha(batch)print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
输出
tensor([[[-0.6033, -0.2785],[-0.5409, -0.2509],[-0.5241, -0.2439],[-0.4974, -0.2357],[-0.5224, -0.2520],[-0.4887, -0.2361]],[[-0.6033, -0.2785],[-0.5409, -0.2509],[-0.5241, -0.2439],[-0.4974, -0.2357],[-0.5224, -0.2520],[-0.4887, -0.2361]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
- 另外请注意,我们在上面的
MultiHeadAttention
类中添加了一个线性投影层 (self.out_proj
)。这只是一个不会改变维度的线性变换。在大型语言模型的实现中,使用这样的投影层是一个标准惯例,但并非绝对必要(最近的研究表明,移除该层不会影响模型性能);