当前位置: 首页 > news >正文

【笔记】KV-cache

KV-cache

  • KV-cache原理
  • 计算公式
  • 举个例子

Transformer原理讲解参考此文

KV-cache原理

在这里插入图片描述
解码时这三个矩阵的大小不同,事实上,通常是一个向量,而是矩阵。向量表示新的token
在这里插入图片描述

在注意力机制中,我们首先对查询向量(query vector)和键矩阵(key matrix)进行点积运算,然后对结果应用 s o f t m a x softmax softmax,将其作为对值矩阵(value matrix)的加权和。在自回归解码(Auto-regressive decoding)中,模型在给定所有之前的上下文的情况下逐词生成,因此键矩阵 K K K和值矩阵 V V V包含整个序列的信息,而查询向量(query vector)仅包含我们刚刚看到的最后一个词的信息。你可以将查询向量与键矩阵的点积理解为在当前词与所有之前的词之间同时进行的注意力操作。因为我们是逐词生成序列,所以键矩阵和值矩阵实际上并没有太多变化,这个词对应键矩阵的一列和值矩阵的一行。关键的一点是,一旦我们计算出该词的嵌入表示,它将不会再改变,无论后面生成多少词。然而,模型仍然需要在每个后续步骤中计算该词的键 k k k和值向量 v v v,这导致了矩阵向量乘法的数量呈二次增长,从而导致非常缓慢的计算速度。
在这里插入图片描述
在这里插入图片描述
当模型读取一个新单词时,它会像以前一样生成查询 Vector,但缓存了键和值矩阵的先前值,因此不再需要为以前的上下文计算这些向量,相反,只需要为键 Matric 计算一个新列,为值 Matrix 计算一个新行,然后像往常一样继续使用点积和 s o f t m a x softmax softmax 来计算注意力。
在这里插入图片描述
KV 缓存在 Transformer 中的作用是加速自注意力层的计算。通常,自注意力层会处理整个嵌入序列,而通过 KV 缓存,只需要传入之前的 k k k v v v 缓存以及当前词的新嵌入。自注意力层会根据当前词的嵌入计算新的键和值向量,并将它们追加到 KV 缓存中。这样,无需重复计算之前的键和值向量,从而显著加快了计算速度,同时仍然可以生成高效的自回归解码序列。
在这里插入图片描述
接下来,需要将这些键和值矩阵存储在 GPU 的内存中,以便在处理下一个词时可以随时检索它们。请注意,模型中当前词与前一个词交互的唯一部分是自注意力层。在其他层(例如位置嵌入、层归一化和前馈神经网络)中,当前词与之前的上下文没有交互。因此,当使用 KV 缓存时,对于每个新词只需执行恒定量的计算工作,而且即使序列长度增加,这些计算量也不会随之增加。

计算公式

K V − c a c h e = 2 ∗ P r e c i s i o n ∗ N l a y e r s ∗ D m o d e l ∗ S e q l e n ∗ B a t c h KV-cache = 2*Precision*N_{layers}*D_{model}*Seqlen*Batch KVcache=2PrecisionNlayersDmodelSeqlenBatch
其中,
2 2 2=two matrices for K K K and V V V
P r e c i s o n Precison Precison=bytes per parameter(eg:4 for fp32)
N l a y e r N_{layer} Nlayer = layer in the model
D m o d e l D_{model} Dmodel = dimension of the embeddings
S e q l e n Seqlen Seqlen=length of context in tokens
B a t c h Batch Batch = batch sieze

举个例子

Example :OPT-30B
2 2 2 = two matrices for K and V
P r e c i s o n Precison Precison = 2(fp16)
N l a y e r N_{layer} Nlayer = 48
D m o d e l D_{model} Dmodel = 7168
S e q l e n Seqlen Seqlen =1024
B a t c h Batch Batch = 128
所以,KV-cache = 224871681024*128 = 180G


http://www.mrgr.cn/news/63250.html

相关文章:

  • AI-基本概念-多层感知器模型/CNN/RNN/自注意力模型
  • Langchain调用模型使用FAISS
  • 合合信息亮相PRCV大会,探讨生成式AI时代的内容安全与系统构建加速
  • 书生-第四期闯关:完成SSH连接与端口映射并运行hello_world.py
  • 【Spring】Spring 核心和设计思想
  • 如何在Linux系统中使用LVM进行磁盘管理
  • 如何实现PLC系统时钟显示在HMI上?
  • 地下隧道、管廊非接触式二维位移监测裂纹、衬砌、支护结构损伤识别、隧道病害诊断等问题解决方式——变焦视觉位移监测仪
  • C++初阶(八)--初识模板
  • 制作一个简易恒流电子负载教程,实战教程,单片机程序,电路图,方案
  • 基于字符的图片验证码识别算法的设计与实现
  • springcloud通过MDC实现分布式链路追踪
  • 九识智能与徐工汽车达成战略合作,共绘商用车未来新蓝图
  • SAP ABAP开发学习——BADI增强操作步骤示例2
  • 在阿里云快速启动Umami玩转网页分析
  • 一位专科生面上网络安全的经验总结_网络安全专科
  • 视频批量裁剪工具
  • 探索智能投顾:正大金融数据分析如何优化市场策略
  • 【自动化测试】APP UI 自动化(安卓手机)-本地环境搭建
  • SSID,即Service Set Identifier(服务设置的表示符号)
  • CBAM填报攻略:关键点解析与实操案例分享
  • 台式电脑如何改ip地址:全面解析与实操指南
  • 成功解决:notepad++搜索结果窗口不见了,怎么找回?
  • 【无人机设计与控制】四旋翼无人机飞行姿态(ADRC)自抗扰控制Matlab仿真
  • msys2更换国内源(多个文件(不是3个文件的版本!))
  • 2024年重磅综述:探索深度多模态数据融合的学术前沿动态!