llama3 implemented from scratch 笔记
github地址:https://github.com/naklecha/llama3-from-scratch?tab=readme-ov-file
分词器的实现
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json
import matplotlib.pyplot as plttokenizer_path = "Meta-Llama-3-8B/tokenizer.model"
special_tokens = ["<|begin_of_text|>","<|end_of_text|>","<|reserved_special_token_0|>","<|reserved_special_token_1|>","<|reserved_special_token_2|>","<|reserved_special_token_3|>","<|start_header_id|>","<|end_header_id|>","<|reserved_special_token_4|>","<|eot_id|>", # end of turn] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
tokenizer = tiktoken.Encoding(name=Path(tokenizer_path).name,pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",mergeable_ranks=mergeable_ranks,special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
)tokenizer.decode(tokenizer.encode("hello world!"))
读取模型文件
model = torch.load("Meta-Llama-3-8B/consolidated.00.pth")
print(json.dumps(list(model.keys())[:20], indent=4))
["tok_embeddings.weight","layers.0.attention.wq.weight","layers.0.attention.wk.weight","layers.0.attention.wv.weight","layers.0.attention.wo.weight","layers.0.feed_forward.w1.weight","layers.0.feed_forward.w3.weight","layers.0.feed_forward.w2.weight","layers.0.attention_norm.weight","layers.0.ffn_norm.weight","layers.1.attention.wq.weight","layers.1.attention.wk.weight","layers.1.attention.wv.weight","layers.1.attention.wo.weight","layers.1.feed_forward.w1.weight","layers.1.feed_forward.w3.weight","layers.1.feed_forward.w2.weight","layers.1.attention_norm.weight","layers.1.ffn_norm.weight","layers.2.attention.wq.weight"
]
with open("Meta-Llama-3-8B/params.json", "r") as f:config = json.load(f)
config
{'dim': 4096,'n_layers': 32,'n_heads': 32,'n_kv_heads': 8,'vocab_size': 128256,'multiple_of': 1024,'ffn_dim_multiplier': 1.3,'norm_eps': 1e-05,'rope_theta': 500000.0}
dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]
vocab_size = config["vocab_size"]
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]
norm_eps = config["norm_eps"]
rope_theta = torch.tensor(config["rope_theta"])
将文本转换为 tokens(这里没有手动实现分词器)
这里用 tiktoken 作为 tokenizer
prompt = "the answer to the ultimate question of life, the universe, and everything is "
tokens = [128000] + tokenizer.encode(prompt)
print(tokens)
tokens = torch.tensor(tokens)
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
print(prompt_split_as_tokens)
[128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]
['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']
将令牌嵌入(这里用的内置的神经网络模块,也没有手动实现)
总之,[17, 1]
的 tokens 现在变成了 [17, 4096]
的嵌入向量
embedding_layer = torch.nn.Embedding(vocab_size, dim)
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
token_embeddings_unnormalized.shape
torch.Size([17, 4096])
使用均方根 RMS 对嵌入进行归一化
这里并不会进行形状的改变,值只是进行了归一化,为了防止除以零的情况,会设置一个 norm_eps
。
# def rms_norm(tensor, norm_weights):
# rms = (tensor.pow(2).mean(-1, keepdim=True) + norm_eps)**0.5
# return tensor * (norm_weights / rms)
def rms_norm(tensor, norm_weights):return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
tensor.pow(2)
:
这一步将输入张量 tensor 中的每个元素进行平方操作。假设 tensor
的形状为 (batch_size, seq_len, hidden_dim)
,那么 tensor.pow(2)
的结果形状仍然是 (batch_size, seq_len, hidden_dim)
,但每个元素都被平方了。
tensor.pow(2).mean(-1, keepdim=True)
:
这一步计算张量在最后一个维度(即 hidden_dim
维度)上的均值。mean(-1, keepdim=True)
表示在最后一个维度上求均值,并且保持该维度的形状(即 keepdim=True
)。结果的形状为 (batch_size, seq_len, 1)
。
tensor.pow(2).mean(-1, keepdim=True) + norm_eps
:
这一步在均值的基础上加上一个小的常数 norm_eps
,以避免除零错误。norm_eps
通常是一个非常小的正数,例如 1e-8。
torch.rsqrt(...)
:
torch.rsqrt
是平方根的倒数(即 1 / sqrt(x)
)。这一步计算的是 1 / sqrt(mean + norm_eps)
,即 RMS 值的倒数。
tensor * torch.rsqrt(...)
:
这一步将输入张量 tensor
乘以 RMS
值的倒数,从而实现归一化。归一化后的张量在最后一个维度上的 RMS
值为1。
* norm_weights
:
最后,将归一化后的张量乘以 norm_weights
。norm_weights
是一个可学习的权重张量,形状为 (hidden_dim,)
,用于对归一化后的特征进行缩放。
通常,归一化操作会将特征缩放到一个固定的范围,然而,不同的特征可能需要不同的缩放因子来更好地适应模型的需求。通过引入可学习的权重,模型可以根据数据的特点和任务的需求,自动调整每个特征的缩放因子。
构建 transformer 的第一层
归一化
# 这里是attention之前的normalization
token_embeddings = rms_norm(token_embeddings_unnormalized, model["layers.0.attention_norm.weight"])
token_embeddings.shape
torch.Size([17, 4096])
手动实现注意力
从模型中加载查询(query)、键(key)、值(value)和输出(output)向量时,我们注意到它们的形状分别是 [4096x4096]、[1024x4096]、[1024x4096]、[4096x4096]。
假设我们有以下形状的矩阵:
query_matrix: [4096x4096]
key_matrix: [1024x4096]
value_matrix: [1024x4096]
output_matrix: [4096x4096]
我们可以通过以下方式解开它们:
解开查询
q_layer0 = model["layers.0.attention.wq.weight"]
head_dim = q_layer0.shape[0] // n_heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)
q_layer0.shape
torch.Size([32, 128, 4096])
32 是 llama3
的注意力头的数量,128 是查询向量的大小,4096 是令牌嵌入的大小。
实现第一层的第一个头
查询权重矩阵的大小是 [128, 4096]
q_layer0_head0 = q_layer0[0]
q_layer0_head0.shape
torch.Size([128, 4096])
现在将查询权重矩阵和令牌嵌入相乘,以接收对令牌的查询
最终的形状是 [17, 128]
,这是因为有 17 个令牌,和 128 长度的查询。
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)
q_per_token.shape
torch.Size([17, 128])
位置编码
当前阶段是,我们为提示(prompt)中的每个令牌都有一个查询向量,但是单独的查询向量并不知道它在提示中的位置,在例子中,使用了三次 “the” 标记的查询向量([1, 128]
)。使用 RoPE 旋转位置编码来执行这些旋转。
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
q_per_token_split_into_pairs.shape
torch.Size([17, 64, 2])
这一步将查询向量分成对,并对每对应用旋转角度偏移。
用复数的点积来旋转向量
# 生成一个从0到1的等间隔序列,分成64个部分。这个序列表示每个部分的归一化位置
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64
# 计算频率freqs,这里的rope_theta是llama3给的500000.0
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)# 生成一个 [17, 64] 的矩阵,其中每一行对应一个标记的频率。torch.outer函数计算两个向量的外积,生成一个矩阵
freqs_for_each_token = torch.outer(torch.arange(17), freqs)
# 将频率转换为复数形式,其中实部为1,虚部为频率。torch.polar函数生成复数形式的向量,其中模为1,相位为频率
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
freqs_cis.shape
等间隔序列 zero_to_one_split_into_64_parts:
tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,0.9844])
频率:
tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])
z = r ⋅ e i θ z=r \cdot e^{i \theta} z=r⋅eiθ表示一个旋转角度为 θ \theta θ的复数
旋转矩阵中的每一个元素freqs_cis[i,j]
可以表示为 e i ⋅ f r e q s _ f o r _ e a c h _ t o k e n [ i , j ] e ^{i⋅{freqs\_for\_each\_token[i,j]}} ei⋅freqs_for_each_token[i,j],其中 i i i是标记的索引, j j j是频率的索引。
这就是所有 token 对应的旋转矩阵,下面进行相乘得到旋转后的所有 token
的查询
现在我们有了每个 token 查询的复数(角度变化向量)
我们可以将我们的查询转换为复数然后进行点积以根据位置旋转查询。
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
q_per_token_as_complex_numbers.shape
torch.Size([17, 64])
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis
q_per_token_as_complex_numbers_rotated.shape
torch.Size([17, 64])
这样就是旋转后的查询。
得到旋转后的向量之后
通过将查询再次从复数看成实数(从[a+bj]
的存储形式变成[a, b]
),可以得到
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)
q_per_token_split_into_pairs_rotated.shape
torch.Size([17, 64, 2])
旋转后的对现在已经合并,我们现在有了一个新的查询向量,其形状是[17, 128]
。
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
q_per_token_rotated.shape
torch.Size([17, 128])
键,几乎和查询的处理是一样的
键也生成维度为 128 的键向量。键的权重数量只有查询(queries)的 1/4,这是因为键的权重在 4 个注意力头之间共享,以减少所需的计算量。键也像查询一样旋转以添加位置信息,因为同样的原因。
k_layer0 = model["layers.0.attention.wk.weight"]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)
k_layer0.shape
torch.Size([8, 128, 4096])
k_layer0_head0 = k_layer0[0]
k_layer0_head0.shape
torch.Size([128, 4096])
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)
k_per_token.shape
torch.Size([17, 128])
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
k_per_token_split_into_pairs.shape
torch.Size([17, 64, 2])
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
k_per_token_as_complex_numbers.shape
torch.Size([17, 64])
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
k_per_token_split_into_pairs_rotated.shape
torch.Size([17, 64, 2])
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
k_per_token_rotated.shape
torch.Size([17, 128])
在这个阶段,现在有每个令牌的查询和键的旋转值
下一步,把查询和键相乘
这样做会给我们一个分数,将每个标记与其他标记进行映射。这个分数描述了每个标记的查询与每个标记的键之间的关系。这就是自注意力机制(Self-Attention)😃
注意力分数矩阵(qk_per_token)的形状为 [17x17],其中 17 是提示中的标记数量。
详细解释
在自注意力机制中,我们通过计算查询(queries)和键(keys)之间的点积来生成注意力分数。注意力分数矩阵描述了每个标记与其他标记之间的关系。
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
qk_per_token.shape
torch.Size([17, 17])
现在我们必须 mask 查询键分数
在训练过程中,Llama3 的未来标记的 qk
分数被掩码。
为什么?因为在训练过程中,我们只使用过去的标记来预测未来的标记。
因此,在推理过程中,我们将未来的标记设置为零。
# 显示注意力分数矩阵的热力图
def display_qk_heatmap(qk_per_token):_, ax = plt.subplots()# 生成热力图,使用 `viridis` 颜色映射im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')ax.set_xticks(range(len(prompt_split_as_tokens)))ax.set_yticks(range(len(prompt_split_as_tokens)))ax.set_xticklabels(prompt_split_as_tokens)ax.set_yticklabels(prompt_split_as_tokens)ax.figure.colorbar(im, ax=ax)display_qk_heatmap(qk_per_token)
# 生成一个掩码矩阵, 初始都为-inf
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
# 将掩码矩阵转换为上三角矩阵,diagonal=1保留对角线下一个元素及其以上的元素,其余为0
mask = torch.triu(mask, diagonal=1)
mask
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
qk_per_token_after_masking = qk_per_token + mask
display_qk_heatmap(qk_per_token_after_masking)
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
display_qk_heatmap(qk_per_token_after_masking_after_softmax)
Values
值权重在每 4 个注意力头(所以总共 8 个注意力头)之间共享,以节省计算量。这意味着每个注意力头使用相同的值权重矩阵。
v_layer0 = model["layers.0.attention.wv.weight"]
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)
v_layer0.shape
torch.Size([8, 128, 4096])
第一层, 第一个权重矩阵为:
v_layer0_head0 = v_layer0[0]
v_layer0_head0.shape
torch.Size([128, 4096])
值向量
我们现在使用值权重来获取每个标记的注意力值,其大小为 [17x128]
,其中 17 是提示中的标记数量,128 是每个标记的值向量的维度。
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)
v_per_token.shape
torch.Size([17, 128])
注意力
在自注意力机制中,我们将注意力分数矩阵与值矩阵相乘,生成最终的注意力输出。注意力输出的形状为 [17x128]
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
qkv_attention.shape
torch.Size([17, 128])
多头注意力
我们现在有了第一层和第一个注意力头的注意力值。现在,我将运行一个循环,对第一层的每个注意力头执行与上述单元格相同的数学运算。
qkv_attention_store = []for head in range(n_heads):q_layer0_head = q_layer0[head]k_layer0_head = k_layer0[head//4] # key weights are shared across 4 headsv_layer0_head = v_layer0[head//4] # value weights are shared across 4 headsq_per_token = torch.matmul(token_embeddings, q_layer0_head.T)k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)mask = torch.triu(mask, diagonal=1)qk_per_token_after_masking = qk_per_token + maskqk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)qkv_attention_store.append(qkv_attention)len(qkv_attention_store)
32
我们现在有了第一层所有 32 个注意力头的 qkv_attention 矩阵。接下来,把所有注意力分数合并成一个大小为 [17x4096] 的大矩阵。
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
stacked_qkv_attention.shape
torch.Size([17, 4096])
权矩阵,最后一个步骤
w_layer0 = model["layers.0.attention.wo.weight"]
w_layer0.shape
在完成第 0 层注意力机制的最后一步是,将注意力输出与权重矩阵相乘。具体来说,我们将最终的注意力输出矩阵与权重矩阵相乘,生成最终的注意力输出。
torch.Size([4096, 4096])
这是一个简单的线性层,所以我们只需要进行矩阵乘法(matmul)。
embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)
embedding_delta.shape
torch.Size([17, 4096])
我们现在有了注意力机制之后的嵌入值变化,这应该加到原始的标记嵌入值上。
embedding_after_edit = token_embeddings_unnormalized + embedding_delta
embedding_after_edit.shape
torch.Size([17, 4096])
我们将其归一化然后运行一个前馈神经网络通过嵌入 δ \delta δ
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])
embedding_after_edit_normalized.shape
torch.Size([17, 4096])
在加载前馈网络(Feed-Forward Network, FFN)的权重并实现前馈网络时,我们需要执行以下步骤:
在 Llama3 中,他们使用了 SwiGLU 前馈网络。这种网络架构在模型需要时能够很好地添加非线性。如今,在大型语言模型(LLMs)中使用这种前馈网络架构是非常标准的。
w1 = model["layers.0.feed_forward.w1.weight"]
w2 = model["layers.0.feed_forward.w2.weight"]
w3 = model["layers.0.feed_forward.w3.weight"]
output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
output_after_feedforward.shape
torch.Size([17, 4096])
在 Llama3 中,前馈网络使用了 SwiGLU 架构。具体来说,前馈网络由三个线性层组成,其中第一个线性层的输出通过 Swish 激活函数,然后与第三个线性层的输出相乘,最后通过第二个线性层生成新的嵌入值。
Swish 激活函数:
Swish 激活函数是一种平滑的非线性函数,定义为:
S w i s h ( x ) = x ⋅ σ ( β x ) Swish(x)=x\cdot \sigma(\beta x) Swish(x)=x⋅σ(βx)
其中, σ \sigma σ 是 sigmoid 函数, β \beta β 是一个可学习的参数(通常设为 1). Swish 激活函数在许多情况下表现优于 ReLU 和其他常见的激活函数.
GLU (Gated Linear Unit):
GLU 是一种门控机制,用于控制信息的流动。GLU 的定义为:
G L U ( a , b ) = a ⋅ σ ( b ) GLU(a, b)=a\cdot \sigma(b) GLU(a,b)=a⋅σ(b)
其中, a a a 和 b b b 是两个线性变换的输出, σ \sigma σ是 sigmoid 函数. GLU 通过门控信号 σ ( b ) \sigma(b) σ(b)来控制 a a a 的信息流动.
SwiGLU 结构:
SwiGLU 结合了 Swish 激活函数和 GLU 结构,定义为:
S w i G L U ( x , W 1 , W 2 , W 3 ) = S w i s h ( x W 1 ) ⋅ σ ( x W 3 ) W 2 SwiGLU(x, W_1, W_2, W_3)=Swish(xW_1)\cdot \sigma(xW_3)W_2 SwiGLU(x,W1,W2,W3)=Swish(xW1)⋅σ(xW3)W2
我们终于在第一层之后为每个标记生成了新的编辑后的嵌入值。
在自注意力机制和前馈网络之后,我们为每个标记生成了新的嵌入值。这些新的嵌入值包含了更多的上下文信息,从而提高了模型的性能和理解能力。
在完成之前,我们还有 31 层要处理(只需要一个循环)。
你可以想象这个编辑后的嵌入值包含了第一层中所有查询的信息。现在,每一层都会对提出的问题进行越来越复杂的编码,直到我们有一个嵌入值,它包含了我们需要了解的关于下一个标记的所有信息。
layer_0_embedding = embedding_after_edit+output_after_feedforward
layer_0_embedding.shape
torch.Size([17, 4096])
总和
final_embedding = token_embeddings_unnormalized
for layer in range(n_layers):qkv_attention_store = []layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])q_layer = model[f"layers.{layer}.attention.wq.weight"]q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)k_layer = model[f"layers.{layer}.attention.wk.weight"]k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)v_layer = model[f"layers.{layer}.attention.wv.weight"]v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)w_layer = model[f"layers.{layer}.attention.wo.weight"]for head in range(n_heads):q_layer_head = q_layer[head]k_layer_head = k_layer[head//4]v_layer_head = v_layer[head//4]q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))mask = torch.triu(mask, diagonal=1)qk_per_token_after_masking = qk_per_token + maskqk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)qkv_attention_store.append(qkv_attention)stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)w_layer = model[f"layers.{layer}.attention.wo.weight"]embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)embedding_after_edit = final_embedding + embedding_deltaembedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])w1 = model[f"layers.{layer}.feed_forward.w1.weight"]w2 = model[f"layers.{layer}.feed_forward.w2.weight"]w3 = model[f"layers.{layer}.feed_forward.w3.weight"]output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)final_embedding = embedding_after_edit+output_after_feedforward
现在我们有了最终的嵌入, 这是模型对下一个令牌的最好猜测
嵌入的形状和常规令牌的形状相同 [ 17 , 4096 ] [17, 4096] [17,4096].
final_embedding = rms_norm(final_embedding, model["norm.weight"])
final_embedding.shape
torch.Size([17, 4096])
最后, 将嵌入解码成令牌值
我们将使用输出解码器将最终嵌入解码成令牌
model["output.weight"].shape
torch.Size([128256, 4096])
我们使用最后一个标记的嵌入值来预测下一个值。
根据《银河系漫游指南》这本书,42 是“生命、宇宙以及一切的终极问题的答案”。所以大多数 LLMs 在这里都会回答 42.
# 通过线性层生成 logits 向量, 训练过程中隐式调用了 softmax
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
logits.shape
torch.Size([128256])
预测的 token number 是 2983, 解码后是 42
next_token = torch.argmax(logits, dim=-1)
next_token
tensor(2983)
tokenizer.decode([next_token.item()])
'42'