当前位置: 首页 > news >正文

【开源项目】Excel手撕AI算法深入理解(三):时序(RNN、mamba)

项目源码地址:https://github.com/ImagineAILab/ai-by-hand-excel.git

一、RNN

1. RNN 的核心思想

RNN 的设计初衷是处理序列数据(如时间序列、文本、语音),其核心特点是:

  • 隐藏状态(Hidden State):保留历史信息,充当“记忆”。

  • 参数共享:同一组权重在时间步间重复使用,减少参数量。

2. RNN 的数学表达

对于一个时间步 t:

  • 输入:xt​(当前时间步的输入向量)。

  • 隐藏状态:ht​(当前状态),ht−1​(上一状态)。

  • 输出:yt​(预测或特征表示)。

  • 参数:权重矩阵 和偏置 ​。

  • 激活函数:σ(通常为 tanh 或 ReLU)。

更新隐藏状态的核心操作

数学本质:非线性变换

  • At​ 是当前时间步的“未激活状态”,即隐藏状态的线性变换结果(上一状态 ht−1​ 和当前输入 xt​ 的加权和)。

  • ⁡tanh 是双曲正切激活函数,将 At​ 映射到 [-1, 1] 的范围内:

  • 作用:引入非线性,使RNN能够学习复杂的序列模式。如果没有非线性,堆叠的RNN层会退化为单层线性变换。

梯度稳定性

  • tanh⁡tanh 的导数为:

  • 梯度值始终小于等于1,能缓解梯度爆炸(但可能加剧梯度消失)。

  • 相比Sigmoid(导数最大0.25),tanh⁡tanh 的梯度更大,训练更稳定。

3. RNN 的工作流程

前向传播
  1. 初始化隐藏状态 ℎ0h0​(通常为零向量)。

  2. 按时间步迭代计算:

    • 结合当前输入 xt​ 和上一状态 ht−1​ 更新状态 ht​。

    • 根据ht​ 生成输出 yt​。

反向传播(BPTT)

通过时间反向传播(Backpropagation Through Time, BPTT)计算梯度:

  • 沿时间轴展开RNN,类似多层前馈网络。

  • 梯度需跨时间步传递,易导致梯度消失/爆炸

4. RNN 的典型结构

(1) 单向RNN(Vanilla RNN)
  • 信息单向流动(过去→未来)。

  • 只能捕捉左侧上下文。

(2) 双向RNN(Bi-RNN)
  • 两个独立的RNN分别从左到右和从右到左处理序列。

  • 最终输出拼接或求和,捕捉双向依赖。

(3) 深度RNN(Stacked RNN)
  • 多个RNN层堆叠,高层处理低层的输出序列。

  • 增强模型表达能力。

5. RNN 的局限性

(1) 梯度消失/爆炸
  • 长序列中,梯度连乘导致指数级衰减或增长。

  • 后果:难以学习长期依赖(如文本中相距很远的词关系)。

(2) 记忆容量有限
  • 隐藏状态维度固定,可能丢失早期信息。

(3) 计算效率低
  • 无法并行处理序列(必须逐时间步计算)。

6. RNN 的代码实现(PyTorch)

import torch.nn as nnclass VanillaRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super().__init__()self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):# x: [batch_size, seq_len, input_size]out, h_n = self.rnn(x)  # out: 所有时间步的输出y = self.fc(out[:, -1, :])  # 取最后一个时间步return y

7. RNN vs. 其他序列模型

特性RNN/LSTMTransformerMamba
长序列处理中等(依赖门控)差(O(N2))优(O(N))
并行化不可并行完全并行部分并行
记忆机制隐藏状态全局注意力选择性状态

8. RNN 的应用场景

  • 文本生成:字符级或词级预测。

  • 时间序列预测:股票价格、天气数据。

  • 语音识别:音频帧序列转文本。

二、mamba

1. Mamba 的诞生背景

Mamba(2023年由Albert Gu等人提出)是为了解决传统序列模型(如RNN、Transformer)的两大痛点:

  1. 长序列效率问题:Transformer的Self-Attention计算复杂度为 O(N2),难以处理超长序列(如DNA、音频)。

  2. 状态压缩的局限性:RNN(如LSTM)虽能线性复杂度 O(N),但隐藏状态难以有效捕捉长期依赖。

Mamba的核心创新:选择性状态空间模型(Selective SSM),结合了RNN的效率和Transformer的表达力。

2. 状态空间模型(SSM)基础

Mamba基于结构化状态空间序列模型(S4),其核心是线性时不变(LTI)系统:

  • h(t):隐藏状态

  • A(状态矩阵)、B(输入矩阵)、C(输出矩阵)

  • 离散化(通过零阶保持法):

其中

关键特性

  • 线性复杂度 O(N)(类似RNN)。

  • 理论上能建模无限长依赖(通过HiPPO初始化 A)。

3. Mamba 的核心改进:选择性(Selectivity)

传统SSM的局限性:A,B,C 与输入无关,导致静态建模能力。
Mamba的解决方案:让参数动态依赖于输入(Input-dependent),实现“选择性关注”重要信息。

选择性SSM的改动:
  1. 动态参数化

    • B, C, ΔΔ 由输入xt​ 通过线性投影生成:

  1. 这使得模型能过滤无关信息(如文本中的停用词)。
  2. 硬件优化

    • 选择性导致无法卷积化(传统SSM的优势),但Mamba设计了一种并行扫描算法,在GPU上高效计算。

4. Mamba 的架构设计

Mamba模型由多层 Mamba Block 堆叠而成,每个Block包含:

  1. 选择性SSM层:处理序列并捕获长期依赖。

  2. 门控MLP(如GeLU):增强非线性。

  3. 残差连接:稳定深层训练。

(示意图:输入 → 选择性SSM → 门控MLP → 输出)

Time-Varying Recurrence(时变递归)

作用

打破传统SSM的时不变性(Time-Invariance),使状态转移动态适应输入序列。

  • 传统SSM的离散化参数 Aˉ,Bˉ 对所有时间步相同(LTI系统)。

  • Mamba的递归过程是时变的(LTV系统),状态更新依赖当前输入。

实现方式
  • 离散化后的参数 Aˉt​,Bˉt​ 由 Δt​ 动态控制:

    • Δt​ 大:状态更新慢(保留长期记忆)。

    • Δt​ 小:状态更新快(捕捉局部特征)。

  • 效果:模型可以灵活调整记忆周期(例如,在文本中保留重要名词,快速跳过介词)。

关键点
  • 时变性是选择性的直接结果,因为 Δt​,Bt​,Ct​ 均依赖输入。

Discretization(离散化)

作用

将连续时间的状态空间方程(微分方程)转换为离散时间形式,便于计算机处理。

  • 连续SSM:

  • 离散SSM:

实现方式
  • 使用零阶保持法(ZOH)离散化:

总结

  • Selection:赋予模型动态过滤能力,是Mamba的核心创新。

  • Time-Varying Recurrence:通过时变递归实现自适应记忆。

  • Discretization:将连续理论落地为可计算的离散操作。

5. 为什么Mamba比Transformer更高效?

特性TransformerMamba
计算复杂度O(N2)O(N)
长序列支持内存受限轻松处理百万长度
并行化完全并行需自定义并行扫描
动态注意力显式Self-Attention隐式通过选择性SSM

优势场景

  • 超长序列(基因组、音频、视频)

  • 资源受限设备(边缘计算)

6. 代码实现片段(PyTorch风格)

class MambaBlock(nn.Module):def __init__(self, dim):self.ssm = SelectiveSSM(dim)  # 选择性SSMself.mlp = nn.Sequential(nn.Linear(dim, dim*2),nn.GELU(),nn.Linear(dim*2, dim)def forward(self, x):y = self.ssm(x) + x          # 残差连接y = self.mlp(y) + y          # 门控MLPreturn y

7. Mamba的局限性

  • 训练稳定性:选择性SSM需要谨慎的参数初始化。

  • 短序列表现:可能不如Transformer在短文本上的注意力精准。

  • 生态支持:目前库(如mamba-ssm)不如Transformer成熟。


http://www.mrgr.cn/news/98726.html

相关文章:

  • 使用cursor进行原型图设计
  • 概念实践极速入门 - 常用的设计模式 - 简单生活例子
  • Flutter:图片在弹窗外部的UI布局
  • 一文掌握RK3568开发板Android13挂载Windows共享目录
  • vue3获取defineOptions的值;vue3获取组件实例;vue3页面获取defineOptions的name
  • 分布式热点网络
  • AI大模型学习九:‌Sealos cloud+k8s云操作系统私有化一键安装脚本部署完美教程
  • 集群搭建Weblogic服务器!
  • 《Against The Achilles’ Heel: A Survey on Red Teaming for Generative Models》全文阅读
  • 红宝书第四十七讲:Node.js服务器框架解析:Express vs Koa 完全指南
  • 前端基础之《Vue(5)—组件基础(1)》
  • Kubernetes(K8S)内部功能总结
  • 猫咪如厕检测与分类识别系统系列【六】分类模型训练+混合检测分类+未知目标自动更新
  • 【Vue】从 MVC 到 MVVM:前端架构演变与 Vue 的实践之路
  • shell 编程之正则表达式与文本处理器
  • centos7停服yum更新kernel失败解决办法
  • C++中变量、函数存储、包括虚函数多态实现机制说明
  • Deno 统一 Node 和 npm,既是 JS 运行时,又是包管理器
  • chili3d调试笔记2+添加web ui按钮
  • 基础学习:(6)nanoGPT