LLM Inference: An MLA Implementation
1. Overall Flow
At a high level, the MLA computation runs as follows: the input x is down-projected to a low-rank latent for q (via wq_a) and another for k/v (via wkv_a); each latent is RMS-normalized and up-projected (wq_b, wkv_b); q and k are each split into a non-rotary ("nope") part and a rotary ("rope") part that carries the position encoding. Crucially, only the small k/v latent and the rotary part of k are cached, rather than full per-head keys and values.
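Since the smaller cache is the whole point of MLA, it helps to put numbers on it before reading the code. A minimal back-of-the-envelope sketch, using the same toy hyperparameters as the demo in section 2 (illustrative values only, not DeepSeek's production settings): a vanilla multi-head KV cache stores full per-head k and v for every token, while MLA stores one shared latent plus the small rotary part of k.

    # per-token cache size, using the toy hyperparameters from the demo below
    n_heads = 2
    qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 8, 4, 8
    kv_lora_rank = 6

    # standard MHA: full k and v for every head
    mha_cache = n_heads * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)
    # MLA: one shared latent plus the rotary part of k
    mla_cache = kv_lora_rank + qk_rope_head_dim

    print(mha_cache, mla_cache)  # 40 10, a 4x reduction even at this tiny scale

The gap widens as n_heads grows, because the MLA cache size is independent of the number of heads.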
2. Implementation Code
import math
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # the learnable gain g of RMSNorm
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # keeps the denominator away from zero
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


def rotate_half(x):
    # (x1, x2) -> (-x2, x1): the 90-degree rotation used by RoPE
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super(RotaryEmbedding, self).__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float().unsqueeze(1)
        freqs = t @ inv_freq.unsqueeze(0)
        freqs = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, q, k):
        cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
        sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)


class MLA(nn.Module):
    def __init__(self, dim, n_heads, q_lora_rank, kv_lora_rank,
                 qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
                 max_seq_len, max_batch_size):
        super().__init__()
        # hidden size
        self.dim = dim
        # number of attention heads
        self.n_heads = n_heads
        # rank that q is compressed down to
        self.q_lora_rank = q_lora_rank
        # rank that k/v are compressed down to
        self.kv_lora_rank = kv_lora_rank
        # per-head q/k dim without rotary position encoding
        self.qk_nope_head_dim = qk_nope_head_dim
        # per-head q/k dim carrying the rotary position encoding
        self.qk_rope_head_dim = qk_rope_head_dim
        # total per-head q/k dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        # per-head v dim
        self.v_head_dim = v_head_dim
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)
        self.register_buffer("kv_cache", torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank))
        self.register_buffer("pe_cache", torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim))

    def forward(self, x, mask=None):
        bs, seq_len, _ = x.shape
        # [bs, seq_len, q_lora_rank]
        q = self.wq_a(x)
        # [bs, seq_len, q_lora_rank]
        q = self.q_norm(q)
        # [bs, seq_len, n_heads*(qk_nope_head_dim+qk_rope_head_dim)]
        q = self.wq_b(q)
        # [bs, seq_len, n_heads, qk_nope_head_dim+qk_rope_head_dim]
        q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)
        # split along the last dim:
        #   q_nope: [bs, seq_len, n_heads, qk_nope_head_dim]
        #   q_pe:   [bs, seq_len, n_heads, qk_rope_head_dim]
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]
        kv = self.wkv_a(x)
        # split along the last dim:
        #   kv:   [bs, seq_len, kv_lora_rank]
        #   k_pe: [bs, seq_len, qk_rope_head_dim]
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # insert a head dim to match q: [bs, seq_len, 1, qk_rope_head_dim]
        k_pe = k_pe.unsqueeze(2)
        # rotary position encoding
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe)
        # squeeze back to the original shape [bs, seq_len, qk_rope_head_dim]
        k_pe = k_pe.squeeze(2)
        kv = self.kv_norm(kv)
        # cache the shared latent from which both k and v are up-projected
        self.kv_cache[:bs, :seq_len, :] = kv
        # cache the rotary-position-encoded part of k
        self.pe_cache[:bs, :seq_len, :] = k_pe

        # [n_heads*(qk_nope_head_dim + v_head_dim), kv_lora_rank]
        wkv_b = self.wkv_b.weight
        # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
        wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank)

        # ############################ the core of MLA ############################
        # q_nope is roughly x*w_q; here we also fold in w_k (the k-part of wkv_b),
        # i.e. x*w_q*w_k, giving shape [bs, seq_len, n_heads, kv_lora_rank]
        q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
        # then dot with k, which here is the down-projected x (x passed through wkv_a):
        # similarity of the non-rotary parts of q and k, [bs, seq_len, n_heads, seq_len]
        scores_nope = torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bs, :seq_len, :])
        # similarity of the rotary parts of q and k, [bs, seq_len, n_heads, seq_len]
        scores_pe = torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bs, :seq_len, :])
        # ############################ the core of MLA ############################

        # add the two partial scores, then scale
        scores = (scores_nope + scores_pe) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
        if mask is not None:
            scores += mask.unsqueeze(2)
        scores = scores.softmax(dim=-1)

        # with the q/k similarities ready, apply them to v; v is the product of the
        # cached latent and part of wkv_b.
        # first multiply by the latent: [bs, seq_len, n_heads, kv_lora_rank]
        x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bs, :seq_len, :])
        # then by wkv_b[:, -self.v_head_dim:]: [bs, seq_len, n_heads, v_head_dim]
        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = x.contiguous().view(bs, seq_len, -1)
        x = self.wo(x)
        return x


if __name__ == '__main__':
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)
    x = torch.randn(1, 4, 16)
    dim = 16
    n_heads = 2
    q_lora_rank = 10
    kv_lora_rank = 6
    qk_nope_head_dim = 8
    qk_rope_head_dim = 4
    v_head_dim = 8
    max_seq_len = 10
    max_batch_size = 4
    mla = MLA(dim=dim,
              n_heads=n_heads,
              q_lora_rank=q_lora_rank,
              kv_lora_rank=kv_lora_rank,
              qk_nope_head_dim=qk_nope_head_dim,
              qk_rope_head_dim=qk_rope_head_dim,
              v_head_dim=v_head_dim,
              max_seq_len=max_seq_len,
              max_batch_size=max_batch_size)
    print(mla(x))
    print(mla.kv_cache)
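The einsum pair marked "the core of MLA" relies on weight absorption: the non-rotary score is a bilinear form q_nope · (W_uk c), so W_uk can be multiplied into q_nope first and the result dotted directly with the cached latent c, meaning full per-head keys are never materialized or cached. A minimal sketch of that equivalence, with randomly initialized stand-ins for q_nope, the latent, and the k-part of wkv_b (shapes copied from the demo above):

    import torch

    torch.manual_seed(0)
    bs, seq_len, n_heads = 1, 4, 2
    qk_nope_head_dim, kv_lora_rank = 8, 6

    q_nope = torch.randn(bs, seq_len, n_heads, qk_nope_head_dim)  # non-rotary query part
    c_kv = torch.randn(bs, seq_len, kv_lora_rank)                 # cached latent (wkv_a + RMSNorm output)
    w_uk = torch.randn(n_heads, qk_nope_head_dim, kv_lora_rank)   # k-part of wkv_b, reshaped per head

    # naive path: up-project the latent into per-head keys, then dot with q
    k_nope = torch.einsum("btc,hdc->bthd", c_kv, w_uk)            # [bs, t, n_heads, qk_nope_head_dim]
    scores_naive = torch.einsum("bshd,bthd->bsht", q_nope, k_nope)

    # absorbed path (what the MLA forward above does): fold w_uk into q,
    # then dot directly with the low-rank latent
    q_absorbed = torch.einsum("bshd,hdc->bshc", q_nope, w_uk)
    scores_absorbed = torch.einsum("bshc,btc->bsht", q_absorbed, c_kv)

    print(torch.allclose(scores_naive, scores_absorbed, atol=1e-5))  # True

The same absorption happens on the value side: the attention weights are first applied to the latent, and the v-part of wkv_b is multiplied in afterwards.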
References:
https://zhuanlan.zhihu.com/p/16730036197
https://github.com/wyf3/llm_related/tree/main/deepseek_learn