当前位置: 首页 > news >正文

【随笔】为什么transformer的FFN先升维后降维FFN的作用


Transformer 前馈神经网络(FFN)结构

在 Transformer 模型中,每个编码器和解码器层都包含一个前馈神经网络(FFN)模块。这个模块的结构特点是先将输入的维度提升,再缩小回去。具体来说,假设输入向量的维度为 d model d_{\text{model}} dmodel,前馈神经网络会将维度升高至 d ff d_{\text{ff}} dff,然后再降回 d model d_{\text{model}} dmodel。这一过程通常可表示为以下两个线性变换:

  1. 提升维度(Expand Dimension):第一个线性变换将输入向量的维度从 d model d_{\text{model}} dmodel 提升至 d ff d_{\text{ff}} dff,并应用激活函数。
  2. 缩小维度(Reduce Dimension):第二个线性变换将维度从 d ff d_{\text{ff}} dff 缩小回 d model d_{\text{model}} dmodel,使得输出与输入维度一致,便于下一层的处理。

FFN 公式

假设输入向量为 X \mathbf{X} X,FFN 的计算过程如下:

FFN ( X ) = ReLU ( X W 1 + b 1 ) W 2 + b 2 \text{FFN}(\mathbf{X}) = \text{ReLU}(\mathbf{X} \mathbf{W}_1 + \mathbf{b}_1) \mathbf{W}_2 + \mathbf{b}_2 FFN(X)=ReLU(XW1+b1)W2+b2

其中:

  • X ∈ R d model \mathbf{X} \in \mathbb{R}^{d_{\text{model}}} XRdmodel 是输入向量。
  • W 1 ∈ R d model × d ff \mathbf{W}_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}} W1Rdmodel×dff 是第一个线性变换的权重矩阵。
  • b 1 ∈ R d ff \mathbf{b}_1 \in \mathbb{R}^{d_{\text{ff}}} b1Rdff 是第一个线性变换的偏置。
  • ReLU \text{ReLU} ReLU 是激活函数,应用在第一个线性变换的输出上。
  • W 2 ∈ R d ff × d model \mathbf{W}_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}} W2Rdff×dmodel 是第二个线性变换的权重矩阵。
  • b 2 ∈ R d model \mathbf{b}_2 \in \mathbb{R}^{d_{\text{model}}} b2Rdmodel 是第二个线性变换的偏置。

通过这种结构,前馈神经网络模块实现了输入-输出维度一致,但在中间层的神经元数量更多,具有更高的表达能力。

为什么要先提升维度后缩小维度?

这种 “先提升后缩小” 的设计有以下几个原因:

  1. 增强模型的非线性表达能力:中间的高维空间( d ff d_{\text{ff}} dff)允许 FFN 对输入数据进行更复杂的非线性变换,进而学习更复杂的模式。
  2. 捕捉特征的多样性:通过将维度提升至更高,FFN 可以更容易捕捉输入特征中潜在的细微差异,这在自然语言处理中尤其重要。
  3. 低维输入输出的一致性:尽管中间层的维度较高,最终输出回归至 d model d_{\text{model}} dmodel,使得 FFN 的输入和输出维度一致,方便后续层的处理和连接。
  4. 增强Transformer对于distributed的文本特征的组合能力,从而获取更多、更复杂的语义信息

Transformer模型中前馈层的作用是什么?

前馈层(FFN)在Transformer模型中的作用是对每个位置的词向量进行独立的非线性变换。尽管注意力机制能够捕捉序列中的全局依赖,但前馈层通过增加模型的深度和复杂度,为模型引入必要的非线性,从而增强模型的表达能力。每个编码器和解码器层都包含一个FFN,它对所有位置的表示进行相同的操作,但并不共享参数。

总结

Transformer 中的前馈神经网络通过先提升再缩小维度,实现了在相对较低输入维度( d model d_{\text{model}} dmodel)条件下提升网络的表达能力,同时保持输入输出的一致性。



http://www.mrgr.cn/news/62814.html

相关文章:

  • 一步到位Python Django部署,浅谈Python Django框架
  • P10打卡——pytorch实现车牌识别
  • LangChain学习笔记2 Prompt 模板
  • 【MySQL】SQL菜鸟教程(一)
  • C++语言的学习路线
  • 【解决】okhttp的java.lang.IllegalStateException: closed错误
  • 搜维尔科技:Manus数据手套在水下捕捉精确的手指动作, 可以在有水的条件下使用
  • 全面解析云渲染:定义、优势、分类与发展历程
  • java-参数传递与接收
  • 基于SSM+小程序的宿舍管理系统(宿舍1)
  • 【VM实战】VMware迁移到VirtualBox
  • 【c++篇】:模拟实现string类--探索字符串操作的底层逻辑
  • vite构建Vue3项目:封装公共组件,发布npm包,自定义组件库
  • 利用GATK对RNA-seq数据做call SNP 或 INDEL分析
  • VScode + PlatformIO 了解
  • 案例精选 | 石家庄学院大日志场景下的实名审计实践
  • Rust: 加密算法库 ring 如何用于 RSA 数字签名?
  • 罗马仕、西圣、安克充电宝哪款品牌更好?综合测评对比谁是TOP.1
  • 为Meta Spark准备3D模型
  • vue简介
  • 从0开始学习shell脚本
  • JS面试八股文(四)
  • windows环境下,使用docker搭建redis集群
  • java程序打包为一个exe程序
  • Python import package
  • [TypeError]: type ‘AbstractProvider‘ is not subscriptable