当前位置: 首页 > news >正文

SageAttention2

Paper: https://arxiv.org/abs/2411.10958

https://github.com/thu-ml/SageAttention

“SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization”由Jintao Zhang等人撰写。文章提出SageAttention2,通过线程级INT4量化、Q矩阵平滑、两级累加策略等技术,在提升注意力计算效率的同时保持精度,在多种模型上取得优异性能。

0.简介

  1. 研究背景:随着序列长度增加,注意力机制的二次时间复杂度使其高效实现变得关键。现有优化方法各有局限,如线性和稀疏注意力方法适用范围有限,常用的注意力方法如FlashAttention、xformers和SageAttention等虽有不错性能,但SageAttention存在INT8矩阵乘法速度慢和特定GPU加速受限的问题。
  2. 相关工作
    • FlashAttention:将注意力计算中的矩阵按token维度分块并行计算,降低计算复杂度,提升计算效率。
    • 量化:通过将高精度矩阵转换为低精度格式加速矩阵乘法,不同量化器在数值格式和粒度上有差异。
    • SageAttention:基于FlashAttention的分块策略,将Q、K量化为INT8,对K进行预处理以保持精度,对(\tilde{P})、V使用FP16并降低累加器精度加速计算,但存在局限性。
  3. SageAttention2方法
    • 平滑Q:由于INT4数值范围有限,存在异常值影响量化精度。通过减去Q每个块的均值平滑Q,结合对K的平滑,将(QK^{\top})计算分解,分预处理和注意力两个阶段,提升INT4量化精度。
    • INT4线程级量化:在SageAttention的基础上,提出线程级量化,根据GPU线程和矩阵内存布局,以更细粒度进行量化,避免额外去量化开销,提升精度。
    • (\tilde{P}V)的FP8量化:鉴于(\tilde{P})的分布特点,将(\tilde{P})、V量化为FP8(E4M3),采用静态量化和按通道量化,在保持精度的同时利用GPU张量核心加速计算。
    • 针对FP22累加器的FP32 MMA缓冲区:因实际CUDA实现中FP8矩阵乘法累加器为FP22导致精度损失,采用两级累加策略,用FP32缓冲区累加FP22值,还提出可选的平滑V技术提升精度。
  4. 实验
    • 实验设置:在多种语言、图像和视频生成模型上进行实验,对比SageAttention2与多种基线方法,使用不同数据集和指标评估。
    • 内核速度和精度:SageAttention2在RTX4090上比FlashAttention2和xformers快约3倍和4.5倍,在Hopper GPU上与FlashAttention3(fp8)速度相当但精度更高,在CogvideoX模型上精度优于其他基线方法。
    • 端到端性能:SageAttention2在多种模型上保持端到端指标,可视化结果显示其生成的图像和视频质量高,且能显著加速模型,如在CogvideoX (1.5 - 5B)上实现1.8倍加速且无指标损失。
    • 消融实验:线程级量化、平滑Q和两级累加技术的开销分别为0.35%、3.7%和0%,平滑V可提升精度,但在部分模型中无明显效果。
  5. 研究结论:SageAttention2是一种高效且准确的量化注意力机制,通过创新量化方法和精度提升技术,在速度和精度上优于多种现有方法,在不同类型模型中保持端到端性能,为加速注意力计算提供有效方案。

作用

SageAttention2 是清华大学陈键飞团队提出的高效注意力计算框架,其核心作用是通过低比特量化与硬件优化技术,显著提升注意力计算效率,同时保持模型精度。以下是其具体作用与技术实现:

SageAttention2是一种全新的即插即用注意力模块,其作用主要包括以下几个方面:

  • 加速推理速度:采用4 - Bit量化技术,在多种硬件平台上实现了显著的推理加速。例如,在RTX4090上较FlashAttention2推理速度提升三倍,在A100上提升至1.6倍,在L20、L40、L40S上可以实现2倍的加速,为多样化环境中的AI模型部署提供了可能。
  • 保持模型精度:通过对Q、K矩阵进行平滑处理,以及引入Per - thread量化方法等技术手段,克服了低比特量化常见的精度损失问题,在多种大型模型应用中保持了端到端的精度表现,确保了模型的多样性与稳定性。
  • 支持多种应用场景:有助于促进AI绘画、视频生成、文本生成等多种应用场景的落地。以开源视频生成模型CogvideoX - 1.5 - 5B为例,采用SageAttention2后,其端到端的推理速度提升达1.8倍,且在视频生成效果上无损失。

使用

from sageattention import sageattn
attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)

根据提供的代码和搜索结果,以下是关于sageattn函数的详细说明:

  1. 函数功能
    sageattn是SageAttention库提供的注意力计算函数,支持FP16/BF16精度的输入,并针对不同GPU架构(如Ampere、Ada、Hopper)进行了优化。

  2. 参数说明

  • 输入张量q, k, v需为FP16或BF16类型,支持两种形状:
    ◦ 默认布局tensor_layout="HND":形状为(batch_size, head_num, seq_len, head_dim)
    ◦ 布局tensor_layout="NHD":形状为(batch_size, seq_len, head_num, head_dim)
  • is_causal:布尔值,决定是否使用因果掩码(如解码器的自回归场景)。
    因果掩码确保模型在生成序列时只能访问当前及之前的位置信息,而无法“看到”未来的信息。这种掩码通常是一个​​下三角矩阵​​(主对角线及以下为0,其余为负无穷-inf),通过以下方式实现:
    ​​训练阶段​​:模拟自回归生成过程,防止模型作弊(如依赖未来信息)。
    ​​推理阶段​​:保证生成的每个词仅依赖已生成的词,维持时序逻辑

Available APIs​​

(可用API列表 | 黑色背景白字排版内容)

  • sageattn​​
    Automatically selects the optimal kernel based on the GPU to achieve a good performance-accuracy trade-off.
    功能:根据GPU自动选择最优计算内核,平衡性能与精度

  • ​​sageattn_qk_int8_pv_fp16_triton​​
    INT8 quantization for QKᵀ and FP16 for PV using Triton backend.
    量化方案:QKᵀ用INT8,PV用FP16(Triton后端)

  • sageattn_qk_int8_pv_fp16_cuda​​
    INT8 quantization for QKᵀ and FP16 for PV using CUDA backend.
    量化方案:QKᵀ用INT8,PV用FP16(CUDA后端)

  • sageattn_qk_int8_pv_fp8_cuda​​
    INT8 quantization for QKᵀ and FP8 for PV using CUDA backend.
    量化方案:QKᵀ用INT8,PV用FP8(通用CUDA)

  • sageattn_qk_int8_pv_fp8_cuda_sm90​​
    INT8 quantization for QKᵀ and FP8 for PV using CUDA backend, specifically optimized for Hopper GPUs.
    硬件优化:专为Hopper架构GPU(如H100)设计

  • sageattn_varlen​​
    INT8 quantization for QKᵀ and FP16 for PV using Triton backend. Support for varying sequence lengths within the same batch.
    特殊功能:支持同批次变长序列处理

  • sageattention/core.py
    get_cuda_arch_versions
    sageattn
    sageattn_qk_int8_pv_fp16_triton
    sageattn_varlen
    sageattn_qk_int8_pv_fp16_cuda
    sageattn_qk_int8_pv_fp8_cuda
    sageattn_qk_int8_pv_fp8_cuda_sm90

Plug-and-play Example

添加以下代码

import torch.nn.functional as Ffrom sageattention import sageattn
F.scaled_dot_product_attention = sageattn
# 1. 环境
查看GPU硬件信息

nvcc --version # 检查CUDA Toolkit版本(需≥12.4)
nvidia-smi # 验证驱动兼容性(若驱动版本高于Toolkit,需更新CUDA至匹配版本)

RTX 3090 采用的是 NVIDIA Ampere 架构,这是 NVIDIA 的第 2 代 RTX 架构
RTX 4090 基于 NVIDIA 的 Ada Lovelace 架构
# 2. 安装 SageAttention
```shell
# 构建一个虚拟环境
conda create -n env1 python=3.12.3 
# 更新bashrc中的环境变量
conda init bash && source /root/.bashrc# 切换到创建的虚拟环境:my-env
conda activate my-env

requirements.txt

torch==2.5.1
torchvision==0.20.1
torchaudio==2.5.1
triton==3.1.0

pip install -r requirements.txt

安装ModelScope

pip install modelscope

直接安装sageA

pip install sageattention //默认版本1.0.6
pip uninstall sageattention

编译

git clone https://github.com/THUDM/SageAttention2.git
cd SageAttention2# 编译 CUDA 内核(需确保 CUDA 环境已配置)
python setup.py install# 重新编译
# 清除旧编译文件
cd SageAttention
rm -rf build/
# 重新编译并保存日志到文件
python setup.py install --record install.log > compile_output.log 2>&1
# 查看日志
cat compile_output.log

验证安装

import sageattention as sa2
print(sa2.__version__)  # 应输出版本号(如 0.1.0)
# or
python -c "import sageattention; print(sageattention.__version__)"
# 查看 sageattention 是否在正确的 Python 环境路径中
python -c "import sageattention; print(sageattention.__file__)"

ImportError: cannot import name ‘_fused’ from partially initialized module ‘sageattention’ (most likely due to a circular import) (…/sageattention/init.py)

sageattention/init.py 中直接导入了多个 CUDA 扩展模块(如 sageattn_qk_int8_pv_fp8_cuda),但这些模块可能依赖于 _fused 模块。若 _fused 未正确编译或存在循环导入,会导致 ImportError。

3.集成到 Llama 模型

替换 注意力模块
参考:https://blog.csdn.net/gitblog_00216/article/details/146898493

import torch.nn.functional as F
from sageattention import sageattn# 替换原有的注意力机制
F.scaled_dot_product_attention = sageattn

4. 性能优化与测试

4.1 推理加速

input_text = "中国的首都是"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
# 原始推理速度测试
%%timeit
outputs = model.generate(**inputs, max_length=50)
# SageAttention2 加速后测试
%%timeit
outputs = model.generate(**inputs, max_length=50)

预期效果:速度提升约 2-4 倍(取决于 GPU 型号和量化配置)。

4.2 显存优化
INT4 量化:Q/K 矩阵显存占用减少至原版的 1/4。

FP8 保留:P/V 矩阵显存占用减少至原版的 1/2。

5. 注意事项

6. 故障排查

CUDA 内核编译失败:

  • 检查 CUDA 版本与 PyTorch 是否匹配。

  • 确保 ninja 已安装:sudo apt-get install ninja-build

量化误差过大:

  • 尝试启用 SageAttention2 的 mean_smoothing 参数(见官方文档)。

通过上述步骤,您可以在 Llama 模型中无缝集成 SageAttention2,显著提升推理速度并降低显存占用。如需进一步调优,可参考清华大学提供的完整文档。

安装深度学习框架(PyTorch/TensorFlow)​

​​PyTorch 安装​​(推荐):

# 根据 CUDA 版本选择(若需 GPU 支持)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121  # CUDA 12.1
# 或 CPU 版本
pip install torch torchvision torchaudio

引用说明:PyTorch 是 Transformers 的主要依赖库,需优先安装。

​​TensorFlow 安装​​(可选):
bash
pip install tensorflow

安装 huggingface transformers

# pip
pip install transformers
# 若需同时安装数据集和评估工具
pip install transformers datasets evaluate
# or
git clone https://github.com/huggingface/transformers.git
cd transformers
pip install .

​安装扩展功能​​
​​支持语音/视觉任务​​(需额外依赖):

pip install transformers[sentencepiece]  # 支持 T5、ALBERT 等分词
pip install transformers[torchaudio]     # 音频处理支持
pip install transformers[vision]        # 图像处理支持

使用国内镜像源(如阿里云、清华源)解决下载问题:

pip install transformers -i https://mirrors.aliyun.com/pypi/simple/

​​分词器安装 sentencepiece 或 tokenizers:pip install sentencepiece

GPU架构及型号

SUPPORTED_ARCHS = {“8.0”, “8.6”, “8.9”, “9.0”, “12.0”} 对应哪些型号gpu
以下是与 SUPPORTED_ARCHS 中各计算能力对应的部分 NVIDIA GPU 型号:

  1. Compute Capability 8.0 (Ampere 架构)
    GPU 型号:
    A100(数据中心/专业计算卡),A30(数据中心推理卡)

特点:
专为 AI 训练和高性能计算设计,支持 Tensor Cores 和 FP64 高性能计算。

  1. Compute Capability 8.6 (Ampere 架构, 消费级 GPU)
    GPU 型号:
    GeForce RTX 30 系列:RTX 3080、RTX 3090、RTX 3070 Ti 等
    RTX A6000(专业工作站显卡)
    特点:
    针对游戏和创作优化,支持 DLSS 和光线追踪。

  2. Compute Capability 8.9 (Ada Lovelace 架构)
    GPU 型号:
    GeForce RTX 40 系列:RTX 4090、RTX 4080、RTX 4070 Ti 等
    RTX 6000 Ada(专业工作站显卡)
    特点:
    新一代架构,显著提升光线追踪和 AI 性能(如 DLSS 3.0)。

  3. Compute Capability 9.0 (Hopper 架构)
    GPU 型号:
    H100(数据中心/超大规模 AI 模型训练)
    H200(升级版 HBM3 显存)
    特点:
    专为 AI 和超级计算设计,支持 Transformer Engine 和 FP8 精度加速。

  4. Compute Capability 12.0 (Blackwell 架构)
    GPU 型号(2024 年发布):
    B100、B200(下一代数据中心 GPU)
    GB200 Grace-Blackwell Superchip(集成 CPU+GPU 的超算方案)
    特点:
    针对生成式 AI 和大模型优化,显著提升计算密度和能效。

安装docker

Ubuntu/Debian

# 1. 卸载旧版本(如有)
sudo apt-get remove docker docker-engine docker.io containerd runc# 2. 更新包索引并安装依赖
sudo apt-get update
sudo apt-get install -y apt-transport-https ca-certificates curl gnupg lsb-release# 3. 添加 Docker 官方 GPG 密钥
sudo install -m 0755 -d /etc/apt/keyrings
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
sudo chmod a+r /etc/apt/keyrings/docker.gpg# 4. 添加 Docker 仓库
echo \"deb [arch="$(dpkg --print-architecture)" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | \sudo tee /etc/apt/sources.list.d/docker.list > /dev/null# 5. 安装 Docker
sudo apt-get update
sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin# 6. 验证安装
sudo docker run hello-world# 7. 将当前用户加入 docker 组(避免每次输入 sudo)
sudo usermod -aG docker $USER
# 退出终端重新登录后生效

ModelScope 下载模型

https://modelscope.cn/docs/models/download

模型下载到默认的cache_dir目录中:~/.cache/modelscope/hub/models
更改:

# 在用户profile中添加
echo "export MODELSCOPE_CACHE=~/.modelscope_cache" >> ~/.bashrc
  1. 命令行下载
    下载整个模型repo到指定目录
    modelscope download --model 'Qwen/Qwen2-7b' --local_dir 'path/to/dir'
  1. 使用 GIT 下载模型
# 公开模型下载
git lfs install
git clone https://www.modelscope.cn/<owner_name>/<model-name>.git
# 例如: git clone https://www.modelscope.cn/iic/ofa_image-caption_coco_large_en.git# 私有模型下载,前提是您有响应模型权限 方法1
git lfs install
git clone http://oauth2:your_git_token@www.modelscope.cn/<owner_name>/<model-name>.git
# 方法2
git clone http://your_user_name@www.modelscope.cn/<owner_name>/<model-name>.git
# Password for 'http://your_user_name@modelscope.cn':
# input git token
# model
https://www.modelscope.cn/models/Qwen/Qwen2-7b/files

http://www.mrgr.cn/news/98590.html

相关文章:

  • 基于AD9767高速DAC的DDS信号发生器
  • C++学习之工厂模式-套接字通信
  • 【本地图床搭建】宝塔+Docker+MinIO+PicGo+cpolar:打造本地化“黑科技”图床方案
  • IDEA202403 常用设置【持续更新】
  • LWIP学习笔记
  • Android studio打包uniapp插件
  • 【NLP 59、大模型应用 —— 字节对编码 bpe 算法】
  • 使用 Vitis Model Composer 生成 FPGA IP 核
  • 【QT】 QT定时器的使用
  • 常见的 14 个 HTTP 状态码详解
  • 如何在 Windows 安卓子系统 (WSA) 上安装小红书应用
  • rce漏洞学习
  • Ubuntu2404装机指南
  • 【Docker-13】Docker Container容器
  • 虚幻基础:碰撞帧运算
  • UWB定位技术面临的主要挑战
  • go中我遇到的问题总结
  • Redis 分布式锁+秒杀异步优化
  • Git 学习笔记
  • 鸿蒙系统开发状态更新字段区别对比