当前位置: 首页 > news >正文

均方根层标准化(RMSNorm: Root Mean Square Layer Normalization)

文章目录

  • 0 TL;DR
  • 1 背景
  • 2 理论
  • 3 效果
  • 4 LayerNorm的重新中心化到底有没有用?
  • 5 Torch源码

0 TL;DR

LayerNorm的重新中心化可能不是必要的,RMSNorm移除了重新中心化,降低了计算量。实验显示,训练效果和稳定性也更好

1 背景

LayerNorm存在什么问题? 标准化目的是使训练更快,但标准化增加了计算量,降低了标准化带来的训练速度收益。

如果LayerNorm的中心化不是必要的,移除中心化是不是就减少了计算量!

2 理论

先来回顾一下LayerNorm:
神经网络的前馈网络通过线性变化➕非线性激活对输入进行投影变换:
a i = ∑ j = 1 m w i , j x j , y i = f ( a i + b i ) a_i=\sum_{j=1}^{m}w_{i,j}x_j, \quad y_i=f(a_i+b_i) ai=j=1mwi,jxj,yi=f(ai+bi)

但后续网络层的输入分布会变化,出现协变量偏移问题,降低了模型收敛速度。LayerNorm对加权输入 a \boldsymbol a a进行归一化,固定其均值和方差:
a ˉ i = a i − μ σ g i , y i = f ( a ˉ i + b i ) \bar a_i=\frac{a_i - \mu}{\sigma}g_i, \quad y_i=f(\bar a_i+b_i) aˉi=σaiμgi,yi=f(aˉi+bi)

其中, μ \mu μ σ \sigma σ分别是加权输入 a \boldsymbol a a的均值和标准差估计量。

LayerNorm具有重新中心化和重新缩放不变性。重新中心化对输入和权重的偏移噪声不敏感(抗偏移),重新缩放则能在输入和权重随机缩放时保证输出不变(抗伸缩)。

RMSNorm仅关注重新缩放不变性,根据均方根(RMS)统计量对加权输入进行正则化:
a ˉ i = a i RMS ( a ) g i , where RMS ( a ) = 1 n ∑ i = 1 1 a i 2 \bar a_i=\frac{a_i}{\text{RMS}(\boldsymbol a)}g_i,\quad \text{where}\ \text{RMS}(\boldsymbol a)=\sqrt{\frac{1}{n}\sum_{i=1}^1a_i^2} aˉi=RMS(a)aigi,where RMS(a)=n1i=11ai2
当加权输入 a \boldsymbol a a的均值为零时,LayerNorm和RMSNorm完全等价。

由于RMSNorm不需要计算均值,简化了计算量,提升了训练速度!假设中心化对训练的影响很小,这就是可行的!

伸缩不变性:
RMS具有性质: RMS ( α x ) = α RMS ( x ) \text{RMS}(\alpha \boldsymbol x) = \alpha\text{RMS}(\boldsymbol x) RMS(αx)=αRMS(x),因此,对于RMSNorm的通用形式:
y = f ( W x RMS ( a ) ⊙ g + b ) \boldsymbol y=f\bigg(\frac{W\boldsymbol x}{\text{RMS}(\boldsymbol a)}\odot\boldsymbol g+\boldsymbol b\bigg) y=f(RMS(a)Wxg+b)

有:
y ′ = f ( α W x RMS ( α W x ) ⊙ g + b ) = f ( W x RMS ( W x ) ⊙ g + b ) = y ′ \boldsymbol y' =f\bigg(\frac{\alpha W \boldsymbol x}{\text{RMS}(\alpha W \boldsymbol x)}\odot\boldsymbol g+\boldsymbol b\bigg) =f\bigg(\frac{W\boldsymbol x}{\text{RMS}(W x)}\odot\boldsymbol g+\boldsymbol b\bigg) =\boldsymbol y' y=f(RMS(αWx)αWxg+b)=f(RMS(Wx)Wxg+b)=y

3 效果


4 LayerNorm的重新中心化到底有没有用?

如下图所示,将网络权重均值随机初始化为0.2,LayerNorm收敛非常慢,但RMSNorm仍正常工作,可见 RMSNorm对权重初始均值的变化更加鲁棒

5 Torch源码

Qwen的实现,与T5 LayerNorm一致。

class Qwen2RMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""Qwen2RMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypehidden_states = hidden_states.to(torch.float32)variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.to(input_dtype)

http://www.mrgr.cn/news/89831.html

相关文章:

  • SpringBoot中的多环境配置管理
  • 微信开发者工具的快捷键
  • 校验收货地址是否超出配送范围实战3(day09)
  • extends配置项详解
  • firefox PAC代理
  • 探索 C++ 与 LibUSB:开启 USB 设备交互的奇幻之旅
  • 【从零开始系列】DeepSeek-R1:(本地部署使用)思维链推理大模型,开源的神!——Windows/Linux本地环境测试 + vLLM远程部署服务
  • k8s部署rabbitmq
  • 《Kettle实操案例一(全量/增量更新与邮件发送)》
  • 音频进阶学习十二——Z变换
  • 保姆级教程Docker部署KRaft模式的Kafka官方镜像
  • 【服务器知识】如何在linux系统上搭建一个nfs
  • 【Langchain学习笔记(二)】Langchain安装及使用示例
  • HIVE如何注册UDF函数
  • nodejs:express + js-mdict 网页查询英汉词典,能播放.spx 声音
  • Mac上搭建k8s环境——Minikube
  • 【非 root 用户下全局使用静态编译的 FFmpeg】
  • kafka服务端之延时操作实现原理
  • (一)C++的类与对象
  • Jmeter快速实操入门
  • docker安装es及分词器ik
  • 122,【6】buuctf web [护网杯2018] easy_tornado
  • 交叉编译工具链下载和使用
  • TaskBuilder项目实战:创建项目
  • 深入理解 DeepSeek MOE(Mixture of Experts)
  • 【戒抖音系列】短视频戒除-1-对推荐算法进行干扰