当前位置: 首页 > news >正文

RNN之:LSTM 长短期记忆模型-结构-理论详解(Matlab向)

0.前言

递归!循环神经网络Recurrent Neural Network

循环神经网络(又称递归神经网络,Recurrent Neural Network,RNN)。是一种用于处理序列数据的神经网络结构,具有记忆功能,能够捕捉序列中的时间依赖关系。RNN通过循环连接的方式,将当前时间步的输入和上一时间步的隐藏状态作为输入,计算当前时间步的隐藏状态和输出,从而实现对序列数据的建模和预测。

说到递归描述,各位肯定会想到马尔科夫链,循环神经网络和马尔科夫链具有一定的相似性,两者都是用来捕捉和描述序列数据的变化规律。

但两者在其原理上有根本区别:

马尔可夫链需要满足马尔可夫状态,即知道t时刻的状态时,t+1时刻的状态的概率分布明确。我们可以通过转移矩阵基于t时刻的信息计算t+1、t+2等未来时刻中不同状态的预期概率。因此未来状态仅受当前状态影响,过去发生过什么并不重要。因此它可以常常用来描述理想状态下物理场中概率变化目标的属性。而循环神经网络则不需要满足任何前提条件,它可以基于过去任何长短时间的数据,基于统计经验给出未来的最可能状态。

但是相较于循环神经网络而言,马尔可夫链在实际应用中具有天生缺陷:

  • 是马尔可夫链的转移矩阵的概率映射是线性的,并不能很好的描述非线性的长迭代状态(比如,转移矩阵会随时间步长发生变化)。所以马尔可夫链在自然语言中的处理和应用非常拉跨。对于RNN而言,由于其本身结构的复合性,RNN的状态转移可以满足非线性的变化需求。
  • 同时马尔可夫状态的设定过于理想,现实世界的应用工程中,状态空间本身,转移矩阵的难以精确求解。而对于RNN而言,无论满足满足马尔可夫状态,都可以从递归过程中挖掘统计规律。因此在一些具有统计学特性外的外部附加规则的预测任务中(比如游戏抽卡),RNN能够反应与时间步长相关的特征。

然而,对于传统的RNN,随着序列长度的增加,计算得到的梯度在反向传播过程中会逐渐消失或爆炸,导致模型难以训练。这种现象被称为“梯度消失”或“梯度爆炸”,它限制了RNN捕捉和利用长距离依赖关系的能力。这对于许多需要理解长期依赖关系的任务,如自然语言处理、语音识别等,是一个严重的挑战。

LSTM的诞生

1997年,由德国慕尼黑工业大学的计算机科学家Sepp Hochreiter与Jürgen Schmidhuber共同提出了LSTM(Long Short-Term Memory)模型。LSTM是一种特别的RNN,旨在解决传统RNN在处理长序列数据时遇到的梯度消失和梯度爆炸问题。

原文链接:Long Short-Term Memory | MIT Press Journals & Magazine | IEEE Xplore

原文PDF:nc.dvi

Sepp Hochreiter与Jürgen Schmidhuber在实验中对LSTM和之前其他RNN模型做了进一步对比:利用无噪声和有噪声的序列对不同的RNN进行训练,这些序列长度较长,对于有噪声序列而言,其中只有少数数据是重要的,而其余数据则起到干扰作用。

表:LSTM与RTRL(RealTime Recurrent Learning)、ELM(Elman nets,又称单循环网络)和RCC(Recurrent Cascade-Correlation,级联相关学习架构)在长序列无噪声数据中的表现对比

可以发现,相较于RTRL、ELM、RCC等传统模型,LSTM在长序列无噪声模型中的训练成功率更高,同时训练速度更快。

注:表中LSTM的Block为一个LSTM单元,Block为4代表4层LSTM单元进行堆叠,Size代表一个Block或LSTM单元中memory cell(记忆单元,也就是长期记忆)并列的数量,为了方便计算,这些记忆单元共享输入门和输出门和遗忘门。其实,文中所说的Size其实就是我们现在说的隐藏单元数量,由于1997年计算机技术不是很发达,GPU并行计算的运用不多,所以原文的结构写得很复杂。(具体示例结构如原文图2所示,由于是编码训练的形式,原文展示得很复杂,看不懂没关系,不影响后面理解)。

其中,原文对Size的表述是:Memory cell blocks of size S form a structure where S memory cells share the same input gate and the same output gate. These blocks facilitate information storage. However, as with conventional neural nets, it is not so easy to code a distributed input within a single cell. Since each memory cell block has as many gate units as a single memory cell, namely two, the block architecture can be even slightly more efficient (see paragraph on computational complexity in Section).

表:LSTM与RTRL、BPTT(Back-Propagation Through Time)和RCC在长序列无局部规律性数据中的表现对比

由于长期记忆(细胞记忆)的特性,在无局部规律数据中,LSTM的表现相较于RTRL、BPTT和CH模型在训练速度及跟踪能力上有质的差异,不仅训练迭代需求缩小了十倍以上,而且表现出极好的追踪能力。这也是为什么在那么多RNN模型中,LSTM可以获得成功的原因。

简单来说而言:LSTM(Long Short-Term Memory)是一种特殊的递归神经网络(RNN),旨在解决传统RNN在处理长序列数据时遇到的梯度消失和梯度爆炸问题。LSTM通过引入“记忆单元”和“门控机制”,能够高效建模长期和短期依赖关系。

1.长短期记忆网络 LSTM(Long Short-Term Memory)的基本结构

常规LSTM单元的结构如图1所示,在一些论文中常被为“记忆细胞”(Memory Cell),但在中文语境我更倾向于称之为LSTM单元,这样不容易产生混淆。

这是由于GPU的推广,目前的LSTM设计类似于多个“传统的记忆细胞”并联在一起,组成一个集合体,对应LSTM开山论文中的“Cell block”,来传输多维的记忆信息,而在开山论文中,Memory cell代表只能处理数值而非多维向量的单元,因此在当前时代,Memory cell和Cell block在很多语境下基本上是一个东西,所以这个cell究竟是那个cell?这个描述不是很严谨。

一般而言,这就是,当其在处理数据时,数据内容在内部循环流动,序列数据或任何长度的单维度数据(dim=1)在输入LSTM时,按照不同的时间步长t,依次输入模型,在这一过程中当前的细胞状态(cell sate)和隐藏状态(hidden state)在不断变化。

它包含三个关键的门控机制:输入门(Input Gate)、遗忘门(Forget Gate)和输出门(Output Gate)。其详细信息如下图所示:

图1:LSTM单元经典结构细节(一图流),以下介绍与改图一一对应

图中:

  • 为了简化表达,不同位置weight、Bias代表不同权重乘法和偏置加法操作,尽管名字相同,但不同位置的weight、Bias的参数相互独立
  • “()”括号内的数值代表在该环节传输或该参数的“张量”大小,如Weight(N*N)代表权重矩阵的张量规模是N*N。
  • 短期记忆(隐藏状态 h_t)和长期记忆(细胞状态 C_t)的张量尺度必须一致,均为 N 维向量,N是可以自己需求设计数值,称为隐藏单元的数量

1.1.遗忘门(Forget Gate)

遗忘门决定了前一时刻的记忆细胞状态中有多少信息需要保留或遗忘。

它接收当前的输入X_t和上一时间步的隐藏状态h_{t-1}​,通过Sigmoid函数输出一个0到1之间的值,表示遗忘的比例。

计算公式为:

f_t=sigmoid(W_{f1}\cdot h_{t-1}+W_{f2} \cdot X_t+b_f)                    (1)

其中W_{f}b_f​是遗忘门的权重和偏置的参数矩阵,为了区分对应隐藏状态和当前输入的权重矩阵相互独立,格外加了数字下表如:W_{f1},后同。

1.2.输入门(Input Gate)

输入门决定了当前时间步的输入信息中有多少需要被加入到细胞状态中。

如一图流所示,它包含两个部分:首先,使用Sigmoid函数决定哪些信息需要更新(Potential Memory to remember),i_t;其次,使用Tanh函数生成一个候选长期(Potential Long-Trem menmory),也就是\tilde{C}_t

计算公式为:

i_t=sigmoid(W_{i1}\cdot h_{t-1}+W_{i2} \cdot X_t+b_{i1})                                 (2)

\tilde{C}_t=tanh(W_{i3}\cdot h_{t-1}+W_{i4} \cdot X_t+b_{i2})                                        (3)

1.3.细胞状态-长期记忆(Cell-State)

        Cell-State负责在序列的各个时间步长之间存储和传递信息,简单的来说,它就是长期记忆。它像是一个传输带,在整个序列处理过程中,信息在上面流动,只有少量的线性操作被应用于信息本身,信息流动相对简单。

每次迭代中,细胞状态更新公式为:

C_t=C_{t-1}\cdot f_t +i_t\cdot \tilde{C}_t                                                  (4)

根据式(1)、(2)、(3),(4)可完整展开为:

C_t=C_{t-1}\cdot sigmoid(W_{f1}\cdot h_{t-1}+W_{f2} \cdot X_t+b_f)+sigmoid(W_{i1}\cdot h_{t-1}+W_{i2} \cdot X_t+b_{i1})\cdot tanh(W_{i3}\cdot h_{t-1}+W_{i4} \cdot X_t+b_{i2})        (5)

        其中,C_t为该时间步输出的细胞状态,C_{t-1}为上一时间步输出的的细胞状态,其中f_t由遗忘门计算而出,决定了上一刻信息保留的程度。i_t\cdot \tilde{C}_t是选择门从隐藏状态与当前时间步或取信息产生的。总结而言:遗忘门决定了前一时刻的记忆细胞状态中有多少信息需要保留,输入门决定了当前时间步的输入信息中有多少需要被加入到细胞状态中。

1.4.隐藏状态-短期记忆(Hidden State)与输出门(Output Gate)

        输入门决定了基于当前时间步的细胞状态、输入数据和上一时间步的隐藏状态,输出当前时间步隐藏状态h_t,对于序列预测,ht就是LSTM对于下一时间步输入 X_{t+1}的预测。如果你是以序列输出且以回归损失函数反向传播的。

它接收当前的输入xt​和上一时间步的隐藏状态ht−1​,通过Sigmoid函数输出一个0到1之间的值,表示输出的比例。

计算公式为:

h_t=tanh(W_{i3}\cdot h_{t-1}+W_{i4} \cdot X_t+b_{i2})\cdot tanh(C_t)                      (6)

  同样可以完整展开为: 

   h_t=tanh(W_{i3}\cdot h_{t-1}+W_{i4} \cdot X_t+b_{i2})\cdot tanh(C_{t-1}\cdot sigmoid(W_{f1}\cdot h_{t-1}+W_{f2} \cdot X_t+b_f)+sigmoid(W_{i1}\cdot h_{t-1}+W_{i2} \cdot X_t+b_{i1})\cdot tanh(W_{i3}\cdot h_{t-1}+W_{i4} \cdot X_t+b_{i2}))   (7)

2.长短期记忆网络 LSTM(Long Short-Term Memory)的具体设计细节及应用注意事项

2.1.LSTM的参数

LSTM的参数量计算

如基本结构介绍及一图流中所示,LSTM的模型参数主要包括权重和偏置,其中权重根据计算对象分为输入参数权重(Input Weight)和针对隐藏状态的权重(Recurrent Weight),偏置则对应遗忘门、输入门和输出门中的计算。具体信息可以参照一图流,里面已经标的很清楚了

我们来简单地举个例子,下图是我设计的一个简单的LSTM模型,其中输入为长度807的序列数据,LSTM隐藏单元的大小被设置成4000,因此HiddenState和Cell State的大小都为4000

图二

根据图一的展示,对应图二我们可以简单的计算得到

输入参数的权重参数数量为:

 4(输入接口)*[4000*1](单个权重矩阵大小)*[807*3](输入数据维度=所有空间长度累乘*通道长度)=16000*2421=38736000;

隐藏状态的权重参数数量为:

   4(输入接口)*[4000*4000](单个权重矩阵大小)=16000*4000=64000000;

偏置参数的数量为:

4(四个处理点位)*[4000*1](单个偏置矩阵大小)=16000*1=16000;

总参数为:16000+64000000+38736000= 102752000

注:在matlab中,LSTM的设计总是对象单一时间步的,比如这里,当我的输入具有空间属性时,LSTM会先将其flatten至序列再进行处理,所以输入参数的数量就是:所有空间长度累乘*通道长度。并不是你输入一个序列长度为100(S)通道为3的数据,程序就建立输入维度为3的模型,在计算中迭代100次。而是输入维度为300,不限迭代次数模型。

在设计LSTM的输入时,一般不包含空间属性,只有通道属性、批次、和时间步,也就是CBT。在计算参数的例子中,故意加入了空间属性,也是为了说明这一特性。

当LSTM输入图像时

2.2.状态激活方法 StateActivation Fuction 与 门激活方法 GateActivation Function

在LSTM(长短期记忆网络)中,状态激活方法(State Activation Function)和门激活方法(Gate Activation Function)起着至关重要的作用。它们分别用于生成候选记忆状态和控制信息的流动。虽然默认情况下,状态激活方法使用tanh函数,门激活方法使用sigmoid函数,但在实际应用中,这些激活函数可以根据具体需求进行调整和替换。

状态激活方法(State Activation Function)

默认激活函数:tanh

tanh函数(双曲正切函数)将输入映射到-1到1之间,其输出均值为0,这有助于保持梯度在传播过程中相对稳定,从而在一定程度上缓解梯度消失问题。

softsign

softsign函数是另一种非线性激活函数,其输出范围也是-1到1,但与tanh函数相比,softsign函数在输入接近0时具有更大的梯度,这有助于模型在训练初期更快地收敛。此外,softsign函数在输入绝对值较大时,其输出趋近于±1的速度更慢,这有助于避免梯度爆炸问题,以改善模型的收敛速度和稳定性。

ReLU(Rectified Linear Unit)

ReLU函数是深度学习中常用的激活函数,其输出在输入为正数时为输入本身,在输入为负数时为0。ReLU函数具有计算简单、非饱和性的特点,能够有效缓解梯度消失问题。尽管ReLU函数在LSTM中不常用作状态激活函数,但在某些特定的应用场景下,如处理稀疏特征或加速训练过程时,可以尝试使用ReLU函数。然而,需要注意的是,由于ReLU函数的输出范围不是-1到1,直接使用ReLU函数可能会破坏LSTM中记忆细胞的0中心化特性,因此在实际应用中需要进行适当的调整。

门激活方法(Gate Activation Function)

默认激活函数:sigmoid

特点:sigmoid函数将输入映射到0到1之间,其输出值可以解释为信息通过门的概率。sigmoid函数在输入较小或较大时,其输出趋近于0或1,这有助于实现门控机制,即控制信息的流动。

hard-sigmoid

hard sigmoid函数作为sigmoid函数的近似,计算速度更快。

2.3.Bi-LSTM 

Bi-LSTM特别简单,其结构就是两个LSTM模块镜像并列,然后方向相反,因此其有两条流向相反的细胞状态和两条流向相反的隐藏状态,因此其参数量的计算就是简单的LSTM*2即可(隐藏单元数量相同时),读者可以对应图二和图三进行领悟,这里就没必要展开来讲了。

 图三

2.4.LSTM的输入和输出

LSTM及Bi-LSTM的输入一般为序列数据,格式通常为(CBT),即一个通道度、批次度和时间度。在2.1的示例中,我们可以发现LSTM的单次输入总会铺平成向量。

对于输出长度为K的向量X时,其对应输入的权重转化矩阵W为N*K,N为隐藏单元数量,在计算权重时,W在前,输入向量X在后,进行矩阵乘法,W×X,输出的参数大小正好是N*1,对应隐藏状态的大小,使得数据对齐。

LSTM的输出比较重要,分为序列状态输出(sequence output)和最后状态输出(last output),注意!无论那种输出,它们输出的主要是隐藏状态信息,而不是细胞状态信息。

输出为隐藏状态主要是LSTM的设计考量,一是隐藏状态就是LSTM对下一个时间步信息的预测信息,二是隐藏状态在不同时间步的变化更大,含有更多信息,在序列输出时也能够保留更多的特征细节,读者可以对照式5和式7进行理解。

从应用层面而言(通常):

序列状态输出常用于序列到序列的任务,如机器翻译、信号转化等任务,其中每个时间步的输出都对应着输出序列中的一个元素。

最后状态输出常用于序列到标签的任务,如文本分类、情感分析等,其中LSTM的输出用于预测整个序列的类别或标签。回归预测同样如此。

其实无论是序列状态输出还是最后状态输出,在分类和回归问题都可以使用,比如我堆叠三层LSTM,前两层用序列输出的LSTM,最后一层用序列输出的LSTM,或者我直接序列输出到MPL都是可以的。LSTM的设计还是很自由的,各位只需要知道序列输出的信息更多即可。

2.5.LSTM特别的偏置初始化:unit-forget-gate

unit-forget-gate作为偏置初始化策略时,这通常意味着对遗忘门的偏置进行特定的初始化,以鼓励或抑制遗忘门在训练初期的行为。例如,在某些情况下,可能会将遗忘门的偏置初始化为一个较大的正值,以使得模型在训练开始时倾向于保留更多的信息(因为遗忘门的输出接近于1),这有助于模型更稳定地学习长期依赖关系。


http://www.mrgr.cn/news/83376.html

相关文章:

  • 机器学习基础-机器学习的常用学习方法
  • RK3568 Android 13 内置搜狗输入法小计
  • Mysql--基础篇--多表查询(JOIN,笛卡尔积)
  • ffmpeg7.0 合并2个 aac 文件
  • 信息科技伦理与道德3:智能决策
  • 汽车供应链关键节点:物流采购成本管理全解析
  • win32汇编环境,怎么进行乘法运算的
  • 测试开发之面试宝典
  • 01 springboot集成mybatis后密码正确但数据库连接失败
  • JVM与Java体系结构
  • SQL从入门到实战-2
  • 【华为云开发者学堂】基于华为云 CodeArts CCE 开发微服务电商平台
  • Mysql进阶篇
  • 01 Oracle自学环境搭建
  • Lambda expressions in C++ (C++ 中的 lambda 表达式)
  • L1G5000 XTuner 微调个人小助手认知
  • Microsoft 已经弃用了 <experimental/filesystem> 头文件
  • 力扣算法题(基于C语言)
  • 2025年第三届“华数杯”国际赛B题解题思路与代码(Python版)
  • Qt学习笔记第81到90讲
  • 油猴支持阿里云自动登陆插件
  • SpringBoot3
  • java开发springoot
  • 金融项目实战 02|接口测试分析、设计以及实现
  • 鼠标自动移动防止锁屏的办公神器 —— 定时执行专家
  • 【traefik】forwadAuth中间件跨namespace请求的问题