
Notes on the DeepSeek-V3 network architecture source code

1. Main network structure: the model loops over n_layers TransformerBlocks, which are built inside self.layers

class Transformer(nn.Module):
    # Excerpted from the official repo's inference/model.py
    def __init__(self, args: ModelArgs):
        global world_size, rank
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        """
        param tokens: token ids of the input text, shape (batch_size, seq_len)
        return: logits over the vocabulary, shape (batch_size, vocab_size)
        """
        seqlen = tokens.size(1)                  # number of input tokens
        h = self.embed(tokens)                   # embed token ids into word vectors
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
        for layer in self.layers:
            # run the stacked TransformerBlocks
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)[:, -1]                  # keep only the last position
        logits = self.head(h)
        if world_size > 1:
            # the vocabulary is sharded across ranks; gather the full logits
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits
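The mask built in forward is the standard causal mask: torch.full(...).triu_(1) leaves zeros on and below the diagonal and -inf strictly above it, so position i can only attend to positions <= i. A quick standalone check:

import torch

mask = torch.full((4, 4), float("-inf")).triu_(1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])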

2. TransformerBlock structure: matches the block diagram in the paper

class Block(nn.Module):
    """
    The TransformerBlock structure from the paper.
    The attention part, self.attn, uses the MLA technique.
    The feed-forward part, self.ffn, is either an MLP or an MoE:
    the first n_dense_layers blocks are dense layers that use the MLP,
    and every later transformer layer uses the MoE.
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        # pre-norm residual connections around attention and the FFN
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x
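The ternary in __init__ is what produces the dense/MoE split. Assuming V3's published configuration (n_layers = 61, n_dense_layers = 3; both values taken from the repo's config files, so treat them as assumptions here), layers 0-2 get the dense MLP and layers 3-60 get the MoE:

n_layers, n_dense_layers = 61, 3   # assumed V3 config values
ffn_kind = ["MLP" if layer_id < n_dense_layers else "MoE" for layer_id in range(n_layers)]
print(ffn_kind[:5], ffn_kind.count("MoE"))   # ['MLP', 'MLP', 'MLP', 'MoE', 'MoE'] 58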

3. MLP in detail

class MLP(nn.Module):
    """
    The MLP is the FFN of the dense layers: a gated stack of linear maps,
    roughly w2 @ (silu(w1 @ x) * (w3 @ x)).
    """
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim)
        self.w3 = ColumnParallelLinear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
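A shape-level sketch of the same gated-SiLU ("SwiGLU"-style) computation, using plain nn.Linear stand-ins for the tensor-parallel layers so it runs anywhere (in the repo, the Column/RowParallelLinear pair shards inter_dim across ranks):

import torch
import torch.nn.functional as F
from torch import nn

dim, inter_dim = 8, 16
w1 = nn.Linear(dim, inter_dim, bias=False)   # gate projection
w2 = nn.Linear(inter_dim, dim, bias=False)   # down projection
w3 = nn.Linear(dim, inter_dim, bias=False)   # value projection

x = torch.randn(2, 4, dim)                   # (batch, seq, dim)
y = w2(F.silu(w1(x)) * w3(x))                # the SiLU gate modulates the value path
print(y.shape)                               # torch.Size([2, 4, 8]), same as x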

4. MLA in detail

In short, MLA (Multi-head Latent Attention) is a new way of computing Q, K, and V. Ordinary attention computes Q, K, and V from the hidden state with three separate matrix multiplications, and K and V must be cached during decoding. MLA adds a low-rank intermediate step: instead of caching K and V directly, it caches only the compressed latent (the shaded part of the MLA diagram in the paper), which sharply cuts the storage and compute overhead of the K and V matrices.
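Plugging in V3's dimensions makes the saving concrete (n_heads = 128, qk_nope_head_dim = 128, qk_rope_head_dim = 64, v_head_dim = 128, kv_lora_rank = 512; these values come from the published config and should be treated as assumptions here):

# per-token, per-layer cache size, counted in number of values
n_heads, qk_nope, qk_rope, v_dim, kv_lora = 128, 128, 64, 128, 512
naive_per_token = n_heads * (qk_nope + qk_rope) + n_heads * v_dim   # full K and V heads
absorb_per_token = kv_lora + qk_rope                                # latent + one shared RoPE key
print(naive_per_token, absorb_per_token)    # 40960 576
print(naive_per_token / absorb_per_token)   # ~71x smaller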

For the final output there are a "naive" mode and an "absorb" mode; the released code actually runs the absorb path. The difference is what gets cached: in naive mode the model caches k and v themselves, while in absorb mode it caches kv_cache (the compressed latent) and pe_cache (the shared RoPE key). The two computations are mathematically equivalent; absorb is just naive with the up-projection wkv_b folded ("absorbed") into the query side, so attention can be taken directly against kv_cache and pe_cache.
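A tiny numerical check of that equivalence for the non-RoPE part of the attention score (stand-in shapes; wkv_b here plays the role of wkv_b[:, :self.qk_nope_head_dim] in the forward pass below):

import torch

h, d, c, t = 2, 4, 3, 5          # heads, qk_nope_head_dim, kv_lora_rank, cached tokens
q_nope = torch.randn(h, d)       # one query position, per head
wkv_b = torch.randn(h, d, c)     # per-head up-projection: latent (c) -> k_nope (d)
kv = torch.randn(t, c)           # compressed latent cache (kv_cache)

# naive: expand the latent into full keys, then dot with q
k_nope = torch.einsum("tc,hdc->thd", kv, wkv_b)
naive = torch.einsum("hd,thd->ht", q_nope, k_nope)

# absorb: fold the up-projection into q, then dot with the latent directly
q_absorbed = torch.einsum("hd,hdc->hc", q_nope, wkv_b)
absorb = torch.einsum("hc,tc->ht", q_absorbed, kv)

print(torch.allclose(naive, absorb, atol=1e-5))  # True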

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        # 1. compute q
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        # 2. split q into a content part and a RoPE part
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        # 3. compute the kv down-projection
        kv = self.wkv_a(x)
        # 4. split off the RoPE part of k
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        if attn_impl == "naive":  # ordinary kv cache
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:  # the path actually taken: what is cached is no longer the full k and v
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            # absorb the k up-projection into q_nope
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
        if mask is not None:
            scores += mask.unsqueeze(1)
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            # apply scores to the latent, then apply the v up-projection afterwards
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x

5. MoE in detail
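The MoE layer that replaces the dense MLP routes each token through a small top-k subset of routed experts, plus shared experts that process every token, and sums the results. A minimal sketch of this routing pattern (heavily simplified from the repo's Gate/Expert/MoE classes: plain softmax scoring, no score bias or expert grouping, no expert parallelism; every dimension is illustrative):

import torch
import torch.nn.functional as F
from torch import nn

class Expert(nn.Module):
    """Same gated-SiLU structure as the dense MLP, one per routed expert."""
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim, bias=False)
        self.w2 = nn.Linear(inter_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inter_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class SimpleMoE(nn.Module):
    def __init__(self, dim: int = 16, inter_dim: int = 32, n_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, n_experts, bias=False)      # router
        self.experts = nn.ModuleList(Expert(dim, inter_dim) for _ in range(n_experts))
        self.shared = Expert(dim, inter_dim)                   # processes every token

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, shape[-1])                              # flatten to (tokens, dim)
        weights = self.gate(x).softmax(dim=-1)
        weights, idx = weights.topk(self.top_k, dim=-1)        # top-k experts per token
        y = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            hit = (idx == i)                                   # (tokens, top_k) bool
            tok = hit.any(dim=-1).nonzero(as_tuple=True)[0]    # tokens routed to expert i
            if tok.numel() == 0:
                continue
            w = weights[tok][hit[tok]].unsqueeze(-1)           # routing weight per hit token
            y[tok] += expert(x[tok]) * w
        return (y + self.shared(x)).view(shape)

moe = SimpleMoE()
out = moe(torch.randn(2, 4, 16))    # (batch, seq, dim) in -> same shape out
print(out.shape)                    # torch.Size([2, 4, 16])

For reference, the published V3 config routes over 256 experts with 8 activated per token and 1 shared expert (values again assumed from the config files).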

