旋转位置编码(RoPE)讲解和代码实现
旋转位置编码(Rotary Position Embedding:RoPE)讲解和代码实现
1. 什么是位置编码?
在 Transformer 模型中,位置编码的作用是为模型提供序列中每个 token 的位置信息。因为 Transformer 本身没有像 RNN 那样的顺序结构,所以需要通过位置编码来告诉模型 token 的顺序。
2. 为什么需要 RoPE ?
传统的位置编码有一些局限性:
绝对位置编码
- 正弦函数:固定的函数,无法适应不同任务的需求。不同位置编码之间互相独立,无法体现相对位置关系。
- 可学习的位置嵌入:受限于最大序列长度,无法处理更长的序列。
相对位置编码
- T5模型中的相对位置编码:因为需要添加到注意力机制里,所以导致速度较慢,也导致无法使用KV缓存。
RoPE 是一种更高效的位置编码方法,它通过旋转矩阵将位置信息直接注入到 token 的向量中,支持更长的序列长度,并且性能更好。
3. RoPE 的核心思想
RoPE 的核心思想是:通过旋转矩阵将位置信息融入到 token 的向量表示中。具体来说,它会对每个 token 的向量进行旋转,旋转的角度与 token 的位置相关。
4. RoPE 的具体步骤
详细推导过程可以看文章:旋转位置编码(ROPE)公式详细推导过程
以下是 RoPE 的具体实现步骤:
(1)定义旋转矩阵
对于位置 m m m 和维度 i i i,旋转矩阵 R m \mathbf{R}_m Rm 定义为:
R m = ( cos m θ i − sin m θ i sin m θ i cos m θ i ) \mathbf{R}_m = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} Rm=(cosmθisinmθi−sinmθicosmθi)
其中, θ i = 1000 0 − 2 i / d \theta_i = 10000^{-2i/d} θi=10000−2i/d, d d d 是向量的维度。
(2)对 query 和 key 进行旋转
对于 query 向量 q \mathbf{q} q 和 key 向量 k \mathbf{k} k,分别对它们的每一维进行旋转:
q m = R m q \mathbf{q}_m = \mathbf{R}_m \mathbf{q} qm=Rmq
k n = R n k \mathbf{k}_n = \mathbf{R}_n \mathbf{k} kn=Rnk
其中, m m m 和 n n n 分别是 query 和 key 的位置。
(3)计算注意力分数
旋转后的 query 和 key 用于计算注意力分数:
Attention Score = ( R m q ) T ( R n k ) \text{Attention Score} = (\mathbf{R}_m \mathbf{q})^T (\mathbf{R}_n \mathbf{k}) Attention Score=(Rmq)T(R