DeepSeek V3 Network Architecture: Source Code Reading Notes
1. Main network structure: the model is essentially a loop that stacks n_layers TransformerBlocks, built inside self.layers
```python
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        global world_size, rank
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        """
        tokens: token ids of the input text, shape (batch_size, seq_len)
        returns: logits over the vocabulary, shape (batch_size, vocab_size)
        """
        seqlen = tokens.size(1)   # number of input tokens
        h = self.embed(tokens)    # map token ids to embeddings
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
        for layer in self.layers:
            # run the stacked TransformerBlocks
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)[:, -1]
        logits = self.head(h)
        if world_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits
```
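For intuition, here is a minimal sketch (with a hypothetical `seqlen`) of the causal mask built above: during prefill (`seqlen > 1`) the strict upper triangle is `-inf`, so after the softmax each token attends only to itself and earlier positions; single-token decode steps (`seqlen == 1`) need no mask.

```python
import torch

# Causal mask sketch: -inf above the diagonal, 0 on and below it.
seqlen = 4  # hypothetical prefill length
mask = torch.full((seqlen, seqlen), float("-inf")).triu_(1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```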
2. TransformerBlock structure: similar to the figure above
```python
class Block(nn.Module):
    """
    The TransformerBlock structure from the paper.
    The attention part (self.attn) uses MLA.
    The feed-forward part (self.ffn) is either an MLP or an MoE: the first few
    layers are dense layers that use MLP, and every later layer uses MoE.
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x
```
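A small illustration of the `layer_id < args.n_dense_layers` split above; the concrete numbers assume DeepSeek-V3's published config (n_layers = 61 with the first 3 layers dense) and are only an example.

```python
# Which layers get a dense MLP vs. an MoE FFN (assumed V3 config values).
n_layers, n_dense_layers = 61, 3
kinds = ["MLP" if layer_id < n_dense_layers else "MoE" for layer_id in range(n_layers)]
print(kinds[:5])  # ['MLP', 'MLP', 'MLP', 'MoE', 'MoE']
```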
3. MLP in detail
```python
class MLP(nn.Module):
    """
    The FFN of the dense layers: a gated series of linear transforms,
    roughly w2 @ (SiLU(w1 @ x) * (w3 @ x)).
    """
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim)
        self.w3 = ColumnParallelLinear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```
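For readers without the tensor-parallel layers, here is a minimal single-device sketch of the same gated (SwiGLU-style) FFN, with ColumnParallelLinear/RowParallelLinear swapped for plain `nn.Linear` (an assumption made for clarity, not the repo's implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMLP(nn.Module):
    """Single-device sketch of the gated FFN: w2(SiLU(w1(x)) * w3(x))."""
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(inter_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, inter_dim, bias=False)  # up projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

x = torch.randn(2, 8, 16)
print(SimpleMLP(16, 64)(x).shape)  # torch.Size([2, 8, 16])
```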
4. MLA in detail
In short, MLA is a new way of computing Q, K, and V. In standard attention, Q, K, and V are each produced from the hidden state by three separate matrix multiplications, and the resulting K and V must be cached inside the network. MLA instead computes Q, K, and V through an intermediate low-rank step: rather than caching K and V directly, it caches the shaded parts in the figure below, which cuts both the storage and the compute overhead of the K and V matrices; a rough size comparison follows.
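A back-of-the-envelope comparison of per-token, per-layer cache sizes (element counts only, ignoring dtype); the dimensions below assume DeepSeek-V3's published config and should be read as an example:

```python
# Assumed DeepSeek-V3 dims: 128 heads, 192-dim keys (128 nope + 64 rope),
# 128-dim values, 512-dim compressed latent.
n_heads, qk_head_dim, v_head_dim = 128, 192, 128
kv_lora_rank, qk_rope_head_dim = 512, 64

naive_cache = n_heads * (qk_head_dim + v_head_dim)  # full K and V per head
mla_cache = kv_lora_rank + qk_rope_head_dim         # one latent + shared RoPE key
print(naive_cache, mla_cache, naive_cache / mla_cache)  # 40960 576 ~71x smaller
```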
The final output can be computed in a "naive" or an "absorb" fashion; the code actually runs the absorb path. The difference is what gets cached: in naive mode the model caches the full k and v, while in absorb mode it caches kv_cache (the compressed latent) and pe_cache (the shared RoPE keys). The two computations are mathematically equivalent: absorb is an algebraic rearrangement of naive that makes it convenient to work directly with kv_cache and pe_cache, as the small numeric check below demonstrates.
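To make the equivalence concrete, here is a tiny numeric check with hypothetical shapes: absorbing the key up-projection (the qk_nope slice of wkv_b) into the query yields the same attention scores as first expanding the cached latent into full keys.

```python
import torch

h, d, c_dim, t = 4, 16, 8, 5                    # heads, head_dim, latent dim, tokens
q = torch.randn(h, d)                           # q_nope for one query position
W_uk = torch.randn(h, d, c_dim)                 # per-head key up-projection
c = torch.randn(t, c_dim)                       # cached latents for t tokens

naive = torch.einsum("hd,hdc,tc->ht", q, W_uk, c)   # expand K from latents, then dot
absorbed_q = torch.einsum("hd,hdc->hc", q, W_uk)    # absorb W_uk into the query
absorb = torch.einsum("hc,tc->ht", absorbed_q, c)   # dot directly with latents
print(torch.allclose(naive, absorb, atol=1e-5))     # True
```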
```python
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
    bsz, seqlen, _ = x.size()
    end_pos = start_pos + seqlen
    # 1. compute q (optionally through a low-rank bottleneck)
    if self.q_lora_rank == 0:
        q = self.wq(x)
    else:
        q = self.wq_b(self.q_norm(self.wq_a(x)))
    q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
    # 2. split q into a position-free part and a RoPE part
    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    q_pe = apply_rotary_emb(q_pe, freqs_cis)
    # 3. compute the compressed kv representation
    kv = self.wkv_a(x)
    # 4. split off the shared RoPE key
    kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
    if attn_impl == "naive":  # ordinary kv cache
        q = torch.cat([q_nope, q_pe], dim=-1)
        kv = self.wkv_b(self.kv_norm(kv))
        kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
        self.k_cache[:bsz, start_pos:end_pos] = k
        self.v_cache[:bsz, start_pos:end_pos] = v
        scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
    else:  # the path that actually runs: the cache no longer holds the full k/v
        wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
        wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
        # absorb the key up-projection into q_nope
        q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
        self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
        self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
        scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                  torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
    if mask is not None:
        scores += mask.unsqueeze(1)
    scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
    if attn_impl == "naive":
        x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
    else:
        # mix the cached latents first, then apply the value up-projection
        x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
    x = self.wo(x.flatten(2))
    return x
```
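The last two einsums of the absorb branch apply the same trick on the value side: the attention weights mix the cached latents first, and the value up-projection `wkv_b[:, -self.v_head_dim:]` is applied afterwards. A companion check with hypothetical shapes:

```python
import torch

h, t, c_dim, dv = 4, 5, 8, 16                 # heads, tokens, latent dim, v_head_dim
scores = torch.randn(h, t).softmax(dim=-1)    # attention weights for one query
c = torch.randn(t, c_dim)                     # cached latents (kv_cache)
W_uv = torch.randn(h, dv, c_dim)              # per-head value up-projection

naive = torch.einsum("ht,hdc,tc->hd", scores, W_uv, c)   # expand v first, then mix
absorb = torch.einsum("hc,hdc->hd", torch.einsum("ht,tc->hc", scores, c), W_uv)
print(torch.allclose(naive, absorb, atol=1e-5))          # True
```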
5. MoE in detail