[PyTorch][chapter 31][transformer-5] MHA, MQA, GQA
Preface:
This post draws on a translation of "Variants of Multi-head attention: Multi-query (MQA) and Grouped-query attention (GQA)". To balance model quality against efficiency, Google has successively proposed these attention architectures; this chapter walks through MHA, MQA and GQA.
A related aside on overfitting and "grokking" (relevant to the overfitting section below): when a model is overtrained it overfits, i.e., memorizes the training data, which degrades its ability to analyze similar but unseen inputs. But what happens if you keep training anyway? A recent study found that overfitting is not the end of the story.
What's new: Alethea Power and colleagues at OpenAI trained relatively small architectures on algorithmically generated datasets. They observed that continued training leads to an effect they call "grokking": long after overfitting sets in, the Transformer suddenly generalizes to new data.
Key insight: Studying how billion-parameter models trained on millions of examples learn over time requires enormous compute. It is just as illuminating, and far more practical, to study models with hundreds of thousands of parameters trained on thousands of examples; at this scale a model can be trained for many more steps in much less time.
How it works: The authors trained a set of Transformers to classify each solution of 12 binary-operation equations (mostly polynomial).
For each equation they plugged in the possible values of the two variables to enumerate all solutions, yielding roughly 10,000 input-output pairs, which were divided into training, test and validation sets.
To feed an equation into the Transformer, they wrote it in a form like 2∘3=6, but replaced every token with a symbol: for example, 2 with a, the operator ∘ with m, 3 with b, = with q, and so on.
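As an illustration only (the operation, modulus, split ratio and symbol names below are my own assumptions, not details taken from the paper), a dataset of this kind can be generated roughly like this:

import random

# Hypothetical sketch: enumerate one binary operation, division modulo a prime p,
# and write each equation "a o b = c" with arbitrary symbolic tokens.
p = 97
symbols = {n: f"s{n}" for n in range(p)}           # arbitrary token names

pairs = []
for a in range(p):
    for b in range(1, p):                          # b != 0, so a / b mod p is defined
        c = (a * pow(b, p - 2, p)) % p             # modular division via Fermat's little theorem
        x = f"{symbols[a]} o {symbols[b]} q"       # 'o' stands for the operator, 'q' for '='
        y = symbols[c]                             # the solution the model must classify
        pairs.append((x, y))

random.shuffle(pairs)
split = int(0.5 * len(pairs))                      # assumed split ratio
train, val = pairs[:split], pairs[split:]
print(len(pairs), "pairs:", len(train), "train /", len(val), "validation")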
They kept training far beyond the point where training accuracy keeps improving while validation accuracy falls, the classic sign of overfitting.
Results: As training progressed, validation accuracy rose, fell, and then, after the number of training steps had grown by another factor of roughly 1,000, rose again.
(In the case of modular division, validation accuracy climbed from roughly 5% to nearly 100%.) In experiments with reduced datasets, the authors found that the smaller the training set, the more training it took to reach the second rise in accuracy. For example, training on 30% of the examples required roughly 45% more training steps.
Why it matters: Grokking may be how small models and datasets exhibit double descent, the phenomenon in which performance first improves, then degrades, then improves again as the parameter count or the number of training examples grows. In other words, this work suggests we have been misreading what overfitting means: a model can keep learning after it overfits and eventually become quite capable.
Our thoughts: The authors observed this phenomenon at small scale. It remains to be seen whether it still holds for models and datasets of realistic size.
Contents:
- MHA (Multi-Head Attention)
- MQA (Multi-Query Attention)
- GQA (Grouped-Query Attention)
- Performance comparison of the three architectures
- Transformer overfitting
- Code
1. MHA (Multi-Head Attention)
MHA (multi-head attention) was introduced in the 2017 Transformer paper "Attention Is All You Need" by Vaswani et al.
MHA uses multiple Query (Q), Key (K) and Value (V) matrices, one set per head.
For each head we compute dot-product attention over its three matrices Q (query), K (key) and V (value), and then concatenate the results of all heads. Let us walk through these heads and their operations in more detail.
1.1 Workflow recap
Step 1: Subspace projection
With h heads we obtain h queries, h keys and h values, as shown here with h = 8.
Step 2: Scaled dot-product attention
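For a single head, the standard scaled dot-product attention from "Attention Is All You Need" is:

Attention(Q, K, V) = softmax(Q · K^T / sqrt(d_k)) · V

With d_model = 512 and h = 8 heads (the configuration used in the code below), each head operates in a subspace of dimension d_k = d_model / h = 64.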
1.2 What are the drawbacks of multi-head attention during incremental inference?
In incremental generation, for example, a Transformer produces the output text step by step. In such applications the large Key and Value tensors must be loaded from memory over and over again, so memory bandwidth becomes the bottleneck of the computation. The back-of-the-envelope sketch below makes this concrete; the two approaches discussed afterwards (MQA and GQA) address this memory-bandwidth problem.
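The layer count, head count, sequence length and dtype in this sketch are illustrative assumptions, not values from any specific model.

# Rough size of the K/V tensors that must be re-read at every decoding step (illustrative numbers).
num_layers = 24        # assumed
num_heads = 16         # assumed
head_dim = 64          # assumed
seq_len = 2048         # tokens cached so far
bytes_per_value = 2    # fp16

# Both K and V are cached for every layer and every head.
kv_cache_bytes = 2 * num_layers * num_heads * head_dim * seq_len * bytes_per_value
print(f"KV cache per sequence: {kv_cache_bytes / 2**20:.1f} MiB")   # ~192 MiB here

# With MHA this whole cache is streamed from memory for each generated token,
# so memory bandwidth, not compute, becomes the bottleneck; MQA and GQA shrink
# the cache by sharing K/V heads.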
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 24 20:54:16 2024
@author: cxf
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class MHA(nn.Module):
    def __init__(self, embed_dim, num_heads=8, dropout=0.1):
        super(MHA, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == embed_dim
        ), "Embedding size needs to be divisible by num_heads"
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.size()
        # Split the embedding into self.num_heads different pieces
        q = self.query(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Apply the attention weights to the values
        out = torch.matmul(attn_weights, v)
        # Re-assemble all head outputs side by side
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)
        # Final linear layer
        out = self.out(out)
        return out


# Example usage:
batch_size = 64
seq_length = 100
embed_dim = 512
num_heads = 8

# Create some random input data
x = torch.randn(batch_size, seq_length, embed_dim)

# Initialize the MHA module
mha = MHA(embed_dim, num_heads)

# Forward pass
output = mha(x)
print(output.shape)  # Should be [batch_size, seq_length, embed_dim]
2. MQA (Multi-Query Attention)
The MQA (Multi-Query Attention) architecture was proposed in NLP to address the memory-bandwidth problem of incremental inference. As we will see, MQA delivers a substantial speedup at the cost of a slight drop in quality.
2.1 Background and proposal
- Background: In NLP tasks the Transformer captures information from every position of the input sequence through self-attention, and multi-head attention (MHA) strengthens this further by running several attention heads in parallel.
- Proposal: MQA was introduced by a Google team in the 2019 paper "Fast Transformer Decoding: One Write-Head is All You Need" to optimize Transformer decoding and improve inference efficiency.
2.2 Principle and characteristics
Principle: In every Transformer layer, MQA merges the Key and Value projections that the heads would otherwise hold separately: all heads share a single Key matrix and a single Value matrix, while each head keeps its own Query matrix. This sharply reduces the number of Key/Value parameters and, with it, the compute and memory load during inference.
MQA is widely used in LLMs, for example ChatGLM2.
Characteristics:
- Fewer parameters: sharing the Key and Value matrices substantially reduces the parameter count (see the back-of-the-envelope sketch after this list).
- Faster inference: the smaller K/V tensors speed up inference noticeably, which matters for latency-sensitive NLP applications.
- Some quality loss: sharing Key and Value across all heads can reduce the model's ability to distinguish different aspects of the input.
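As a quick back-of-the-envelope check (using the d_model = 512, 8-head, head_dim = 64 sizes from the code below), the savings show up both in the K/V projection weights and in the per-token K/V cache:

# Per-layer comparison of MHA vs. MQA, ignoring biases.
d_model, num_heads, head_dim = 512, 8, 64

# K and V projection parameters
mha_kv_params = 2 * d_model * (num_heads * head_dim)   # 2 * 512 * 512 = 524288
mqa_kv_params = 2 * d_model * head_dim                 # 2 * 512 * 64  = 65536

# K and V cache entries per token
mha_kv_cache = 2 * num_heads * head_dim                # 1024 values
mqa_kv_cache = 2 * head_dim                            # 128 values

print(mha_kv_params, mqa_kv_params)   # 8x fewer K/V projection parameters
print(mha_kv_cache, mqa_kv_cache)     # 8x smaller K/V cache per token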
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 24 17:01:20 2024
@author: chengxf2
"""
import torch
import torch.nn as nn
import math


class MQA(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super(MQA, self).__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.layer_query = nn.Linear(d_model, d_model)
        # Key and Value are projected to a single shared head
        self.layer_key = nn.Linear(d_model, self.d_k)
        self.layer_value = nn.Linear(d_model, self.d_k)
        self.layer_out = nn.Linear(d_model, d_model)

    def projection_down(self, x, num):
        # Reshape (batch, seq_len, num * d_k) -> (batch, num, seq_len, d_k)
        batchSz, seq_len, _ = x.shape
        x = x.view(batchSz, seq_len, num, self.d_k).transpose(1, 2)
        return x

    def forward(self, query, key, value, mask=None):
        # inputs: (batchSz, seq_length, d_model)
        batchSz = query.size(0)
        q = self.layer_query(query)
        k = self.layer_key(key)
        v = self.layer_value(value)
        q = self.projection_down(q, self.num_heads)
        k = self.projection_down(k, 1)  # single shared key head
        v = self.projection_down(v, 1)  # single shared value head
        # attention: broadcasting shares K/V across all query heads
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        print("scores:", scores.shape)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)
        # merge heads
        output = output.transpose(1, 2).contiguous().view(batchSz, -1, self.d_model)
        # final linear projection
        output = self.layer_out(output)
        return output


x = torch.rand(3, 12, 512)
model = MQA(512, 8)
y = model(x, x, x, None)
print(y.shape)
3. GQA (Grouped-Query Attention)
Grouped-Query Attention (GQA) interpolates between multi-query attention (MQA) and multi-head attention (MHA) in large language models. Its goal is to retain MHA's quality at close to MQA's speed: a trade-off between prediction quality and inference performance.
GQA was motivated by the observation that MQA, while greatly accelerating Transformer inference, suffers a noticeable quality loss, whereas MHA has excellent quality but comparatively low inference efficiency. GQA therefore splits the query heads into groups, with each group sharing one key head and one value head, improving inference efficiency while preserving most of the quality.
Concretely, GQA divides the query heads into G groups, and the query heads within a group share the same key and value. This sits between MQA and MHA: when G = 1, GQA degenerates to MQA, with all query heads sharing one key/value; when G equals the number of query heads, GQA degenerates to MHA, with every query head having its own key/value. For example, with 8 query heads and G = 2 groups (the configuration used in the code below), each key/value head is shared by 4 query heads.
Experiments show that with a suitable group size, GQA achieves quality close to MHA at inference efficiency comparable to MQA, which can be essential for heavily loaded systems.
import torch
import torch.nn as nn
import math


class GQA(nn.Module):
    def __init__(self, d_model=512, num_heads_for_query=8, num_heads_for_key=2, head_dim=64):
        # Grouped-Query Attention: with 8 query heads and 2 key/value heads,
        # each K/V head is shared by num_head_groups = 4 query heads.
        super(GQA, self).__init__()
        self.num_heads_for_query = num_heads_for_query
        self.num_heads_for_key = num_heads_for_key
        self.num_head_groups = num_heads_for_query // num_heads_for_key
        self.head_dim = head_dim
        self.d_model = d_model
        print("\n num_heads_for_query %d num_heads_for_key %d num_head_groups %d"
              % (num_heads_for_query, num_heads_for_key, self.num_head_groups))
        self.layer_query = nn.Linear(d_model, num_heads_for_query * head_dim)
        self.layer_key = nn.Linear(d_model, num_heads_for_key * head_dim)
        self.layer_value = nn.Linear(d_model, num_heads_for_key * head_dim)
        self.layer_out = nn.Linear(d_model, d_model)

    def projection_down(self, x, num_heads):
        # Reshape (batch, seq_len, num_heads * head_dim) -> (batch, num_heads, seq_len, head_dim)
        batch_size, seq_len, _ = x.shape
        x = x.view(batch_size, seq_len, num_heads, self.head_dim).transpose(1, 2)
        return x

    def forward(self, query, key, value, mask=None):
        # inputs: (batch_size, seq_length, d_model)
        batch_size, seq_len, d_model = query.shape
        q = self.layer_query(query)
        k = self.layer_key(key)
        v = self.layer_value(value)
        # project into heads
        q = self.projection_down(q, self.num_heads_for_query)
        k = self.projection_down(k, self.num_heads_for_key)
        v = self.projection_down(v, self.num_heads_for_key)
        # split the query heads into groups:
        # (batch_size, num_head_groups, num_heads_for_key, seq_len, head_dim)
        q = q.view(batch_size, self.num_head_groups, self.num_heads_for_key, seq_len, self.head_dim)
        # (batch_size, 1, num_heads_for_key, seq_len, head_dim); PyTorch broadcasting
        # shares each K/V head across its group, equivalent to a torch.repeat along that dim
        k = k.unsqueeze(1)
        v = v.unsqueeze(1)
        # attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        print("scores:", scores.shape)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)
        print("\n output", output.shape)
        # merge heads
        output = output.view(batch_size, self.num_heads_for_query, seq_len, self.head_dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        print(output.shape)
        # final linear projection
        output = self.layer_out(output)
        return output


# Define query, key and value
d_model = 512
batch_size = 1
seq_len = 256
num_heads_for_key = 2
num_heads_for_query = 8
head_dim = 64

model = GQA(d_model, num_heads_for_query, num_heads_for_key, head_dim)
# the module expects (batch_size, seq_len, d_model) inputs
query = key = value = torch.rand(batch_size, seq_len, d_model)
out = model(query, key, value, None)

4. Comparison of the three attention mechanisms
- MHA: every head has its own Q, K and V projections; the best quality, but during incremental decoding the full K/V cache must be re-read at every step, so it is the slowest.
- MQA: all query heads share a single K/V head; the smallest K/V cache and the fastest decoding, at the cost of a measurable quality loss.
- GQA: the query heads are split into G groups that share G K/V heads; with a suitable G it recovers most of MHA's quality at close to MQA's speed.
The short script after this list instantiates the three modules defined above and compares their parameter counts.
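A minimal comparison sketch, assuming the MHA, MQA and GQA classes defined above are in scope: it builds the three modules with the same sizes, counts their parameters and checks that they produce the same output shape.

import torch

d_model, num_heads, num_kv_heads, head_dim = 512, 8, 2, 64
x = torch.rand(2, 16, d_model)

mha = MHA(d_model, num_heads)
mqa = MQA(d_model, num_heads)
gqa = GQA(d_model, num_heads, num_kv_heads, head_dim)

for name, module in [("MHA", mha), ("MQA", mqa), ("GQA", gqa)]:
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {n_params / 1e6:.2f}M parameters")   # MQA < GQA < MHA

# All three return (batch_size, seq_len, d_model)
print(mha(x).shape, mqa(x, x, x).shape, gqa(x, x, x).shape)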
5. Transformer Overfitting
Because Transformers have so many parameters, they overfit easily when the dataset is small: training accuracy reaches 100% while test accuracy stays around 50%. Below is an example from my own experiments.
Example: On a model with 42.3 MB of parameters trained on a dataset of about 4.3K samples, I ran into this problem very quickly.
I spent a week tuning hyperparameters, including the learning rate, batch_size, num_heads, the model's input_size and output_size, and warmup; the effect was limited.
The usual explanation is that when the parameter count is too large, the model finds it easy to memorize the data itself instead of learning the underlying distribution.
The best remedy is to enlarge the dataset or to add noise to the data.
Things I tried:
1: Different learning rates and batch sizes, including warmup (see the warmup code in the next section); the effect was limited.
2: Code review and sanity checks.
3: Different sampling strategies; with these the model could eventually be trained.
4: Re-splitting the dataset (no effect).
5: Reducing model complexity, including dropout and ffn_hidden_size (a hypothetical configuration sketch follows this list).
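As one concrete, hypothetical example of the last item, a lower-capacity configuration of the Transformer class defined in the code section below might look like this (the original values are shown in comments; these numbers are illustrative, not tuned results):

# Hypothetical lower-capacity configuration for the Transformer class defined below.
transformer = Transformer(
    src_vocab_size=1000,
    tgt_vocab_size=1000,
    d_model=256,         # was 512
    num_heads=4,         # was 8
    num_layers=3,        # was 6
    d_ff=512,            # was 2048
    max_seq_length=100,
    dropout=0.3,         # was 0.0
)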
6. Code
https://zhuanlan.zhihu.com/p/686687252
Learning-rate warmup
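The scheduler below combines linear warmup with cosine decay. For training step t, with warmup_steps = W, total_steps = T and minimum learning rate eta_min, it computes:

- t < W:  lr(t) = base_lr * (t + 1) / W
- t >= W: lr(t) = eta_min + (base_lr - eta_min) * (1 + cos(pi * (t - W) / (T - W))) / 2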
import math
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler


class WarmupCosineLR(_LRScheduler):
    def __init__(self, optimizer, total_steps, warmup_steps=0, eta_min=0, last_epoch=-1):
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.eta_min = eta_min
        super(WarmupCosineLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup
            return [(base_lr / self.warmup_steps) * (self.last_epoch + 1)
                    for base_lr in self.base_lrs]
        else:
            # Cosine decay from base_lr down to eta_min
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(math.pi * (self.last_epoch - self.warmup_steps) /
                                  (self.total_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]


# Example usage (num_epochs, dataloader, loss_fn and log_interval are assumed to be defined elsewhere)
model = ...  # your Transformer model
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # initialize the optimizer
total_steps = 10000   # total number of training steps
warmup_steps = 1000   # warmup steps
scheduler = WarmupCosineLR(optimizer, total_steps, warmup_steps)

# Training loop
for epoch in range(num_epochs):
    for step, (inputs, labels) in enumerate(dataloader):
        # Forward pass, loss computation, backward pass and optimizer step
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        # Update the learning rate
        scheduler.step()
        # Logging
        if (step + 1) % log_interval == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{step + 1}/{len(dataloader)}], '
                  f'Loss: {loss.item():.4f}, LR: {scheduler.get_last_lr()[0]:.6f}')
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 31 14:27:32 2024
@author: chengxf2
"""
import torch
import math
import torch.nn as nn
import torch.optim as optim

# https://zhuanlan.zhihu.com/p/686687252
# https://blog.csdn.net/Python_paipai/article/details/141270854


class PositionalEncoding(nn.Module):
    def __init__(self, d_model=90, max_seq_length=1200):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x.shape: [batch_size, seq_len, d_model]
        outputs = x + self.pe[:, :x.size(1)]
        return outputs


class MultiHeadAttention(nn.Module):
    '''
    Multi-head attention splits the input vector into several heads, each of which
    learns its own attention weights. The model can therefore attend to information
    at different positions of the sequence at the same time, which strengthens its
    ability to represent the input: each head may capture features or patterns that
    would be hard to capture with a single head.
    (Traditional CSI feature processing does something similar, extracting features
    in the time domain, PCA space, Fourier space, STFT space and DWT space.)
    '''
    def __init__(self, d_model=512, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # split into different heads, each corresponding to a different subspace
        self.linear_query = nn.Linear(d_model, d_model)
        self.linear_key = nn.Linear(d_model, d_model)
        self.linear_value = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def scaledDotProductAttention(self, q, k, v, mask=None):
        # A similar cross-correlation analysis exists in wireless sensing, e.g. CSI ratios
        # between sub-carriers of the same antenna: the cross-correlation fuses the
        # information into a new feature.
        # input shape: [batch_size, num_heads, seq_len, d_k]
        # attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # attention probabilities
        attn_prop = torch.softmax(scores, dim=-1)
        # weight the values to obtain the output
        output = torch.matmul(attn_prop, v)
        return output

    def split_heads(self, x):
        batch_size, seq_len, d_model = x.shape
        output = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return output.transpose(1, 2)

    def concate_heads(self, x):
        batch_size, num_heads, seq_len, d_k = x.shape
        output = x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return output

    def forward(self, query, key, value, mask=None):
        # project and split into heads
        q = self.linear_query(query)
        k = self.linear_key(key)
        v = self.linear_value(value)
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)
        # scaled dot-product attention
        attn_output = self.scaledDotProductAttention(q, k, v, mask)
        # merge heads and apply the output projection
        output = self.linear_out(self.concate_heads(attn_output))
        return output


class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


class EncoderLayer(nn.Module):
    def __init__(self, d_model=90, num_heads=8, d_ff=1024, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        # post-LN
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        # post-LN
        output = self.norm2(x + self.dropout(ff_output))
        return output


class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)
        output = self.fc(dec_output)
        return output


def infer(transformer, criterion):
    transformer.eval()
    max_seq_length = 1000
    src_vocab_size = 90
    tgt_vocab_size = 90
    # random validation data
    val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
    val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
    with torch.no_grad():
        val_output = transformer(val_src_data, val_tgt_data[:, :-1])
        val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size),
                             val_tgt_data[:, 1:].contiguous().view(-1))
        print(f"Validation Loss: {val_loss.item()}")


def main():
    src_vocab_size = 1000
    tgt_vocab_size = 1000
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    max_seq_length = 100
    dropout = 0.0
    print("\n main ")
    transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads,
                              num_layers, d_ff, max_seq_length, dropout)
    # random training data
    src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_token_length)
    tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_token_length)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
    transformer.train()
    for epoch in range(100):
        optimizer.zero_grad()
        output = transformer(src_data, tgt_data[:, :-1])  # [batch, seq_len, vocab_size]
        print(output.shape)
        pred = output.contiguous().view(-1, tgt_vocab_size)
        label = tgt_data[:, 1:].contiguous().view(-1)
        print("\n pred.shape", pred.shape, label.shape)
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch+1}, Loss: {loss.item()}")


main()
References:
https://zhuanlan.zhihu.com/p/686687252
https://www.youtube.com/watch?v=pVP0bu8QA2w
https://www.youtube.com/watch?v=ulmex-d49cM&t=1428s
https://www.youtube.com/watch?v=LgsiwDRnXls&t=653s
【LLM】一文详解MHA、GQA、MQA原理-CSDN博客