【开源项目】Excel手撕AI算法深入理解(三):时序(RNN、mamba)
项目源码地址:https://github.com/ImagineAILab/ai-by-hand-excel.git
一、RNN
1. RNN 的核心思想
RNN 的设计初衷是处理序列数据(如时间序列、文本、语音),其核心特点是:
-
隐藏状态(Hidden State):保留历史信息,充当“记忆”。
-
参数共享:同一组权重在时间步间重复使用,减少参数量。
2. RNN 的数学表达
对于一个时间步 t:
-
输入:xt(当前时间步的输入向量)。
-
隐藏状态:ht(当前状态),ht−1(上一状态)。
-
输出:yt(预测或特征表示)。
-
参数:权重矩阵 和偏置 。
-
激活函数:σ(通常为
tanh
或ReLU
)。
更新隐藏状态的核心操作
数学本质:非线性变换
-
At 是当前时间步的“未激活状态”,即隐藏状态的线性变换结果(上一状态 ht−1 和当前输入 xt 的加权和)。
-
tanh 是双曲正切激活函数,将 At 映射到 [-1, 1] 的范围内:
-
作用:引入非线性,使RNN能够学习复杂的序列模式。如果没有非线性,堆叠的RNN层会退化为单层线性变换。
梯度稳定性
-
tanhtanh 的导数为:
-
梯度值始终小于等于1,能缓解梯度爆炸(但可能加剧梯度消失)。
-
相比Sigmoid(导数最大0.25),tanhtanh 的梯度更大,训练更稳定。
3. RNN 的工作流程
前向传播
-
初始化隐藏状态 ℎ0h0(通常为零向量)。
-
按时间步迭代计算:
-
结合当前输入 xt 和上一状态 ht−1 更新状态 ht。
-
根据ht 生成输出 yt。
-
反向传播(BPTT)
通过时间反向传播(Backpropagation Through Time, BPTT)计算梯度:
-
沿时间轴展开RNN,类似多层前馈网络。
-
梯度需跨时间步传递,易导致梯度消失/爆炸。
4. RNN 的典型结构
(1) 单向RNN(Vanilla RNN)
-
信息单向流动(过去→未来)。
-
只能捕捉左侧上下文。
(2) 双向RNN(Bi-RNN)
-
两个独立的RNN分别从左到右和从右到左处理序列。
-
最终输出拼接或求和,捕捉双向依赖。
(3) 深度RNN(Stacked RNN)
-
多个RNN层堆叠,高层处理低层的输出序列。
-
增强模型表达能力。
5. RNN 的局限性
(1) 梯度消失/爆炸
-
长序列中,梯度连乘导致指数级衰减或增长。
-
后果:难以学习长期依赖(如文本中相距很远的词关系)。
(2) 记忆容量有限
-
隐藏状态维度固定,可能丢失早期信息。
(3) 计算效率低
-
无法并行处理序列(必须逐时间步计算)。
6. RNN 的代码实现(PyTorch)
import torch.nn as nnclass VanillaRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super().__init__()self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):# x: [batch_size, seq_len, input_size]out, h_n = self.rnn(x) # out: 所有时间步的输出y = self.fc(out[:, -1, :]) # 取最后一个时间步return y
7. RNN vs. 其他序列模型
特性 | RNN/LSTM | Transformer | Mamba |
---|---|---|---|
长序列处理 | 中等(依赖门控) | 差(O(N2)) | 优(O(N)) |
并行化 | 不可并行 | 完全并行 | 部分并行 |
记忆机制 | 隐藏状态 | 全局注意力 | 选择性状态 |
8. RNN 的应用场景
-
文本生成:字符级或词级预测。
-
时间序列预测:股票价格、天气数据。
-
语音识别:音频帧序列转文本。
二、mamba
1. Mamba 的诞生背景
Mamba(2023年由Albert Gu等人提出)是为了解决传统序列模型(如RNN、Transformer)的两大痛点:
-
长序列效率问题:Transformer的Self-Attention计算复杂度为 O(N2),难以处理超长序列(如DNA、音频)。
-
状态压缩的局限性:RNN(如LSTM)虽能线性复杂度 O(N),但隐藏状态难以有效捕捉长期依赖。
Mamba的核心创新:选择性状态空间模型(Selective SSM),结合了RNN的效率和Transformer的表达力。
2. 状态空间模型(SSM)基础
Mamba基于结构化状态空间序列模型(S4),其核心是线性时不变(LTI)系统:
-
h(t):隐藏状态
-
A(状态矩阵)、B(输入矩阵)、C(输出矩阵)
-
离散化(通过零阶保持法):
其中
关键特性:
-
线性复杂度 O(N)(类似RNN)。
-
理论上能建模无限长依赖(通过HiPPO初始化 A)。
3. Mamba 的核心改进:选择性(Selectivity)
传统SSM的局限性:A,B,C 与输入无关,导致静态建模能力。
Mamba的解决方案:让参数动态依赖于输入(Input-dependent),实现“选择性关注”重要信息。
选择性SSM的改动:
-
动态参数化:
-
B, C, ΔΔ 由输入xt 通过线性投影生成:
-
- 这使得模型能过滤无关信息(如文本中的停用词)。
-
硬件优化:
-
选择性导致无法卷积化(传统SSM的优势),但Mamba设计了一种并行扫描算法,在GPU上高效计算。
-
4. Mamba 的架构设计
Mamba模型由多层 Mamba Block 堆叠而成,每个Block包含:
-
选择性SSM层:处理序列并捕获长期依赖。
-
门控MLP(如GeLU):增强非线性。
-
残差连接:稳定深层训练。
(示意图:输入 → 选择性SSM → 门控MLP → 输出)
Time-Varying Recurrence(时变递归)
作用
打破传统SSM的时不变性(Time-Invariance),使状态转移动态适应输入序列。
-
传统SSM的离散化参数 Aˉ,Bˉ 对所有时间步相同(LTI系统)。
-
Mamba的递归过程是时变的(LTV系统),状态更新依赖当前输入。
实现方式
-
离散化后的参数 Aˉt,Bˉt 由 Δt 动态控制:
-
-
Δt 大:状态更新慢(保留长期记忆)。
-
Δt 小:状态更新快(捕捉局部特征)。
-
-
效果:模型可以灵活调整记忆周期(例如,在文本中保留重要名词,快速跳过介词)。
关键点
-
时变性是选择性的直接结果,因为 Δt,Bt,Ct 均依赖输入。
Discretization(离散化)
作用
将连续时间的状态空间方程(微分方程)转换为离散时间形式,便于计算机处理。
-
连续SSM:
-
离散SSM:
实现方式
-
使用零阶保持法(ZOH)离散化:
总结
-
Selection:赋予模型动态过滤能力,是Mamba的核心创新。
-
Time-Varying Recurrence:通过时变递归实现自适应记忆。
-
Discretization:将连续理论落地为可计算的离散操作。
5. 为什么Mamba比Transformer更高效?
特性 | Transformer | Mamba |
---|---|---|
计算复杂度 | O(N2) | O(N) |
长序列支持 | 内存受限 | 轻松处理百万长度 |
并行化 | 完全并行 | 需自定义并行扫描 |
动态注意力 | 显式Self-Attention | 隐式通过选择性SSM |
优势场景:
-
超长序列(基因组、音频、视频)
-
资源受限设备(边缘计算)
6. 代码实现片段(PyTorch风格)
class MambaBlock(nn.Module):def __init__(self, dim):self.ssm = SelectiveSSM(dim) # 选择性SSMself.mlp = nn.Sequential(nn.Linear(dim, dim*2),nn.GELU(),nn.Linear(dim*2, dim)def forward(self, x):y = self.ssm(x) + x # 残差连接y = self.mlp(y) + y # 门控MLPreturn y
7. Mamba的局限性
-
训练稳定性:选择性SSM需要谨慎的参数初始化。
-
短序列表现:可能不如Transformer在短文本上的注意力精准。
-
生态支持:目前库(如
mamba-ssm
)不如Transformer成熟。