
LLM Inference: Implementing MLA (Multi-head Latent Attention)

1. Overall Flow

First, a diagram to give an overall picture of MLA's computation process.
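In formulas, the per-head computation looks roughly like this (a sketch that ignores batch and reshape bookkeeping; the $W$ labels are my own naming, mapped onto the code's modules: $W^{DQ}$ = wq_a, $W^{UQ}$ = wq_b, $W^{DKV}$ = wkv_a, $W^{O}$ = wo, and wkv_b stacks $W^{UK}$ and $W^{UV}$):

$$
\begin{aligned}
[\,q^{\text{nope}};\, q^{\text{rope}}\,] &= W^{UQ}\,\mathrm{RMSNorm}(W^{DQ} x) \\
[\,c^{KV};\, k^{\text{rope}}\,] &= W^{DKV} x \qquad \text{(only } \mathrm{RMSNorm}(c^{KV}) \text{ and } \mathrm{RoPE}(k^{\text{rope}}) \text{ are cached)} \\
s_t &= \frac{\big((W^{UK})^{\top} q^{\text{nope}}\big)^{\top} c^{KV}_t \;+\; \mathrm{RoPE}(q^{\text{rope}})^{\top}\, \mathrm{RoPE}(k^{\text{rope}}_t)}{\sqrt{d_{\text{nope}} + d_{\text{rope}}}} \\
\text{out} &= W^{O}\, \mathrm{concat}_h\Big( W^{UV} \textstyle\sum_t \mathrm{softmax}(s)_t\, c^{KV}_t \Big)
\end{aligned}
$$

The point of absorbing $W^{UK}$ into the query is that attention only ever touches the small cached latent $c^{KV}$, never a per-head full-size $k$ or $v$.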

2. Implementation Code

import math
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # RMSNorm's learnable gain g
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # guards against a zero denominator
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super(RotaryEmbedding, self).__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float().unsqueeze(1)
        freqs = t @ inv_freq.unsqueeze(0)
        freqs = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, q, k):
        cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
        sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)


class MLA(nn.Module):
    def __init__(self, dim, n_heads, q_lora_rank, kv_lora_rank,
                 qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
                 max_seq_len, max_batch_size):
        super().__init__()
        # hidden size
        self.dim = dim
        # number of attention heads
        self.n_heads = n_heads
        # rank that q is low-rank-compressed to
        self.q_lora_rank = q_lora_rank
        # rank that k/v are low-rank-compressed to
        self.kv_lora_rank = kv_lora_rank
        # per-head q/k dim without rotary position embedding
        self.qk_nope_head_dim = qk_nope_head_dim
        # per-head q/k dim with rotary position embedding
        self.qk_rope_head_dim = qk_rope_head_dim
        # total per-head q/k dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        # per-head v dim
        self.v_head_dim = v_head_dim
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(self.kv_lora_rank,
                               self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)

        self.register_buffer("kv_cache",
                             torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank))
        self.register_buffer("pe_cache",
                             torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim))

    def forward(self, x, mask=None):
        bs, seq_len, _ = x.shape
        # [bs, seq_len, q_lora_rank]
        q = self.wq_a(x)
        # [bs, seq_len, q_lora_rank]
        q = self.q_norm(q)
        # [bs, seq_len, n_heads*(qk_nope_head_dim+qk_rope_head_dim)]
        q = self.wq_b(q)
        # [bs, seq_len, n_heads, qk_nope_head_dim+qk_rope_head_dim]
        q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)
        # split along the last dim:
        #   q_nope: [bs, seq_len, n_heads, qk_nope_head_dim]
        #   q_pe:   [bs, seq_len, n_heads, qk_rope_head_dim]
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]
        kv = self.wkv_a(x)
        # split along the last dim:
        #   kv:   [bs, seq_len, kv_lora_rank]
        #   k_pe: [bs, seq_len, qk_rope_head_dim]
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # match q's layout: [bs, seq_len, 1, qk_rope_head_dim]
        k_pe = k_pe.unsqueeze(2)
        # rotary position embedding
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe)
        # squeeze back to [bs, seq_len, qk_rope_head_dim]
        k_pe = k_pe.squeeze(2)
        kv = self.kv_norm(kv)

        # cache the latent shared by k and v; wkv_b later up-projects it into both
        self.kv_cache[:bs, :seq_len, :] = kv
        # cache the rotary-embedded part of k
        self.pe_cache[:bs, :seq_len, :] = k_pe

        # [n_heads*(qk_nope_head_dim + v_head_dim), kv_lora_rank]
        wkv_b = self.wkv_b.weight
        # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
        wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank)

        # ################################ core of MLA ################################
        # q_nope is roughly x*w_q; multiplying by w_k here gives x*w_q*w_k,
        # result: [bs, seq_len, n_heads, kv_lora_rank]
        q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
        # then multiply by k, where k here is the down-projected x (x after wkv_a);
        # result: [bs, seq_len, n_heads, seq_len]
        # this is the q/k similarity for the non-rotary part
        scores_nope = torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bs, :seq_len, :])
        # q/k similarity for the rotary part, result: [bs, seq_len, n_heads, seq_len]
        scores_pe = torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bs, :seq_len, :])
        # ################################ core of MLA ################################

        # add the two parts, then scale
        scores = (scores_nope + scores_pe) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
        if mask is not None:
            scores += mask.unsqueeze(2)
        scores = scores.softmax(dim=-1)

        # with the q/k similarities done, multiply by v; v comes from the kv latent
        # and the v slice of wkv_b
        # first with the kv latent: [bs, seq_len, n_heads, kv_lora_rank]
        x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bs, :seq_len, :])
        # then with wkv_b[:, -self.v_head_dim:]: [bs, seq_len, n_heads, v_head_dim]
        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = x.contiguous().view(bs, seq_len, -1)
        x = self.wo(x)
        return x


if __name__ == '__main__':
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)
    x = torch.randn(1, 4, 16)
    dim = 16
    n_heads = 2
    q_lora_rank = 10
    kv_lora_rank = 6
    qk_nope_head_dim = 8
    qk_rope_head_dim = 4
    v_head_dim = 8
    max_seq_len = 10
    max_batch_size = 4
    mla = MLA(dim=dim, n_heads=n_heads, q_lora_rank=q_lora_rank,
              kv_lora_rank=kv_lora_rank, qk_nope_head_dim=qk_nope_head_dim,
              qk_rope_head_dim=qk_rope_head_dim, v_head_dim=v_head_dim,
              max_seq_len=max_seq_len, max_batch_size=max_batch_size)
    print(mla(x))
    print(mla.kv_cache)
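To make the two "core of MLA" einsums concrete, here is a minimal self-contained check (tensor names and sizes are illustrative, not from the original post) that absorbing the k slice of wkv_b into the query gives the same scores as materializing a full k, and that the same identity holds on the value side:

import torch

torch.manual_seed(0)
bs, s, h, d_nope, d_v, rank = 2, 3, 2, 8, 5, 6
q_nope = torch.randn(bs, s, h, d_nope)   # non-rotary query part
c_kv = torch.randn(bs, s, rank)          # the low-rank latent MLA caches
w_uk = torch.randn(h, d_nope, rank)      # per-head k up-projection (wkv_b's k slice)
w_uv = torch.randn(h, d_v, rank)         # per-head v up-projection (wkv_b's v slice)

# Naive path: up-project the latent into a full k, then take dot products.
k = torch.einsum("btc,hdc->bthd", c_kv, w_uk)
scores_naive = torch.einsum("bshd,bthd->bsht", q_nope, k)

# Absorbed path (what the MLA forward does): fold w_uk into the query first,
# so only the small latent c_kv is ever touched at score time.
q_abs = torch.einsum("bshd,hdc->bshc", q_nope, w_uk)
scores_abs = torch.einsum("bshc,btc->bsht", q_abs, c_kv)
print(torch.allclose(scores_naive, scores_abs, atol=1e-5))  # True

# Value side: attending over the latent and up-projecting afterwards equals
# up-projecting to a full v first and then attending.
attn = scores_abs.softmax(dim=-1)
v = torch.einsum("btc,hdc->bthd", c_kv, w_uv)
out_naive = torch.einsum("bsht,bthd->bshd", attn, v)
out_abs = torch.einsum("bshc,hdc->bshd", torch.einsum("bsht,btc->bshc", attn, c_kv), w_uv)
print(torch.allclose(out_naive, out_abs, atol=1e-5))  # True

This identity is exactly why kv_cache stores only kv_lora_rank numbers per token instead of n_heads*(qk_nope_head_dim + v_head_dim).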

References:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn

