均方根层标准化(RMSNorm: Root Mean Square Layer Normalization)
文章目录
- 0 TL;DR
- 1 背景
- 2 理论
- 3 效果
- 4 LayerNorm的重新中心化到底有没有用?
- 5 Torch源码
0 TL;DR
LayerNorm的重新中心化可能不是必要的,RMSNorm移除了重新中心化,降低了计算量。实验显示,训练效果和稳定性也更好!
1 背景
LayerNorm存在什么问题? 标准化目的是使训练更快,但标准化增加了计算量,降低了标准化带来的训练速度收益。
如果LayerNorm的中心化不是必要的,移除中心化是不是就减少了计算量!
2 理论
先来回顾一下LayerNorm:
神经网络的前馈网络通过线性变化➕非线性激活对输入进行投影变换:
a i = ∑ j = 1 m w i , j x j , y i = f ( a i + b i ) a_i=\sum_{j=1}^{m}w_{i,j}x_j, \quad y_i=f(a_i+b_i) ai=j=1∑mwi,jxj,yi=f(ai+bi)
但后续网络层的输入分布会变化,出现协变量偏移问题,降低了模型收敛速度。LayerNorm对加权输入 a \boldsymbol a a进行归一化,固定其均值和方差:
a ˉ i = a i − μ σ g i , y i = f ( a ˉ i + b i ) \bar a_i=\frac{a_i - \mu}{\sigma}g_i, \quad y_i=f(\bar a_i+b_i) aˉi=σai−μgi,yi=f(aˉi+bi)
其中, μ \mu μ和 σ \sigma σ分别是加权输入 a \boldsymbol a a的均值和标准差估计量。
LayerNorm具有重新中心化和重新缩放不变性。重新中心化对输入和权重的偏移噪声不敏感(抗偏移),重新缩放则能在输入和权重随机缩放时保证输出不变(抗伸缩)。
RMSNorm仅关注重新缩放不变性,根据均方根(RMS)统计量对加权输入进行正则化:
a ˉ i = a i RMS ( a ) g i , where RMS ( a ) = 1 n ∑ i = 1 1 a i 2 \bar a_i=\frac{a_i}{\text{RMS}(\boldsymbol a)}g_i,\quad \text{where}\ \text{RMS}(\boldsymbol a)=\sqrt{\frac{1}{n}\sum_{i=1}^1a_i^2} aˉi=RMS(a)aigi,where RMS(a)=n1i=1∑1ai2
当加权输入 a \boldsymbol a a的均值为零时,LayerNorm和RMSNorm完全等价。
由于RMSNorm不需要计算均值,简化了计算量,提升了训练速度!假设中心化对训练的影响很小,这就是可行的!
伸缩不变性:
RMS具有性质: RMS ( α x ) = α RMS ( x ) \text{RMS}(\alpha \boldsymbol x) = \alpha\text{RMS}(\boldsymbol x) RMS(αx)=αRMS(x),因此,对于RMSNorm的通用形式:
y = f ( W x RMS ( a ) ⊙ g + b ) \boldsymbol y=f\bigg(\frac{W\boldsymbol x}{\text{RMS}(\boldsymbol a)}\odot\boldsymbol g+\boldsymbol b\bigg) y=f(RMS(a)Wx⊙g+b)
有:
y ′ = f ( α W x RMS ( α W x ) ⊙ g + b ) = f ( W x RMS ( W x ) ⊙ g + b ) = y ′ \boldsymbol y' =f\bigg(\frac{\alpha W \boldsymbol x}{\text{RMS}(\alpha W \boldsymbol x)}\odot\boldsymbol g+\boldsymbol b\bigg) =f\bigg(\frac{W\boldsymbol x}{\text{RMS}(W x)}\odot\boldsymbol g+\boldsymbol b\bigg) =\boldsymbol y' y′=f(RMS(αWx)αWx⊙g+b)=f(RMS(Wx)Wx⊙g+b)=y′
3 效果
![](https://i-blog.csdnimg.cn/direct/bc3ee3a78f184bd78de93af13c184d9b.png#pic_center)
![](https://i-blog.csdnimg.cn/direct/ac5510e4a5ab4553a1d381b1211fe4ad.png#pic_center)
4 LayerNorm的重新中心化到底有没有用?
如下图所示,将网络权重均值随机初始化为0.2,LayerNorm收敛非常慢,但RMSNorm仍正常工作,可见 RMSNorm对权重初始均值的变化更加鲁棒 !
![](https://i-blog.csdnimg.cn/direct/d1499ebb196e44dcb7f6f7ead95ffc31.png#pic_center)
5 Torch源码
Qwen的实现,与T5 LayerNorm一致。
class Qwen2RMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""Qwen2RMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypehidden_states = hidden_states.to(torch.float32)variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.to(input_dtype)