DuoAttention:高效处理长上下文推理的 AI 框架,让 LLMs 如虎添翼!
❤️ 如果你也关注大模型与 AI 的发展现状,且对大模型应用开发非常感兴趣,我会快速跟你分享最新的感兴趣的 AI 应用和热点信息,也会不定期分享自己的想法和开源实例,欢迎关注我哦!
🥦 微信公众号|搜一搜:蚝油菜花 🥦
🚀 快速阅读
- DuoAttention 通过区分“检索头”和“流式头”两种注意力头,优化模型的内存使用和计算速度。
- DuoAttention 能在保持模型准确性的同时,减少内存消耗和提高解码及预填充的速度。
- 结合量化技术,DuoAttention 能在单个 GPU 上实现高达 330 万 token 的上下文推理。
正文(附运行示例)
DuoAttention 是什么
DuoAttention 是新型的框架,由 MIT 韩松团队提出,用在提高大型语言模型(LLMs)在处理长上下文时的推理效率。基于区分“检索头”和“流式头”两种注意力头,优化模型的内存使用和计算速度。检索头负责处理长距离依赖,需要完整的键值(KV)缓存,流式头关注最近 token 和注意力汇聚点,只需固定长度的 KV 缓存。两种注意力头让 DuoAttention 在保持模型准确性的同时,减少内存消耗和提高解码及预填充的速度。结合量化技术,DuoAttention 能在单个 GPU 上实现高达 330 万 token 的上下文推理,是处理长文本信息的有效方案。
DuoAttention 的主要功能
- 提高长上下文推理效率:基于优化大型语言模型(LLMs)的注意力机制,DuoAttention 显著提升模型处理长上下文数据的能力。
- 减少内存消耗:区分需要完整 KV 缓存的检索头和只需固定长度 KV 缓存的流式头,减少模型运行时的内存占用。
- 加速解码和预填充过程:DuoAttention 优化模型的解码速度和预填充(Pre-filling)速度,提高 LLMs 的响应时间和处理效率至关重要。
- 保持模型准确性:在减少内存消耗和提高效率的同时,DuoAttention 能保持模型在处理长短上下文任务时的准确性。
DuoAttention 的技术原理
- 注意力头的区分:DuoAttention 将 LLMs 中的注意力头分为检索头和流式头。检索头负责捕捉上下文中的关键信息,对所有 token 进行完整注意力处理;流式头主要处理近期 token 和注意力汇聚点,不需要存储全部历史 KV 状态。
- 检索头的 KV 缓存优化:为检索头保留完整的 KV 缓存,确保能捕捉到长距离依赖信息。
- 流式头的轻量级 KV 缓存:流式头用固定长度的 KV 缓存,减少对内存的需求,支持模型高效处理长序列数据。
- 检索头的自动识别:DuoAttention 用基于优化的算法和合成数据集训练模型,自动识别出哪些头是检索头,在推理时为分配适当的 KV 缓存策略。
- 合成数据集:设计合成数据集和密码召回任务,DuoAttention 能确定哪些注意力头在保留或丢弃 KV 缓存后对模型输出有显著影响,优化模型的长上下文处理能力。
如何运行 DuoAttention
环境设置
训练和评估环境
conda create -yn duo python=3.10
conda activate duoconda install -y git
conda install -y nvidia/label/cuda-12.4.0::cuda-toolkit
conda install -y nvidia::cuda-cudart-dev
conda install -y pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidiapip install transformers accelerate sentencepiece datasets wandb accelerate sentencepiece datasets wandb zstandard matplotlib huggingface_hub
pip install tensor_parallelpip install ninja packaging
pip install flash-attn --no-build-isolation# LongBench评估
pip install seaborn rouge_score einops pandaspip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/# 安装DuoAttention
pip install -e .# 安装Block Sparse Streaming Attention
git clone git@github.com:mit-han-lab/Block-Sparse-Attention.git
cd Block-Sparse-Attention
python setup.py install
演示环境
conda create -yn duo_demo python=3.10
conda activate duo_demo# 安装DuoAttention
pip install -e .conda install -y git
conda install -y nvidia/label/cuda-12.4.0::cuda-toolkit
conda install -y nvidia::cuda-cudart-dev# 安装QServe
git clone git@github.com:mit-han-lab/qserve.git
cd qserve
pip install -e .
pip install ninja packaging
pip install flash-attn==2.4.1 --no-build-isolation
cd kernels
python setup.py install# 安装FlashInfer
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
pip install tensor_parallel
数据集
下载数据集:
mkdir -p datasets
cd datasetswget https://huggingface.co/datasets/togethercomputer/Long-Data-Collections/resolve/main/fine-tune/booksum.jsonl.zst
模型
下载 DuoAttention 支持的模型:
mkdir -p models
cd models# DuoAttention目前支持的评估模型
huggingface-cli download togethercomputer/Llama-2-7B-32K-Instruct --local-dir Llama-2-7B-32K-Instruct
huggingface-cli download gradientai/Llama-3-8B-Instruct-Gradient-1048k --local-dir Llama-3-8B-Instruct-Gradient-1048k
huggingface-cli download gradientai/Llama-3-8B-Instruct-Gradient-4194k --local-dir Llama-3-8B-Instruct-Gradient-4194k
huggingface-cli download mistralai/Mistral-7B-Instruct-v0.2 --local-dir Mistral-7B-Instruct-v0.2
huggingface-cli download mistralai/Mistral-7B-Instruct-v0.3 --local-dir Mistral-7B-Instruct-v0.3# 使用SmoothQuant和QServe的W8A8KV4模型
huggingface-cli download mit-han-lab/Llama-3-8B-Instruct-Gradient-1048k-w8a8kv4-per-channel --local-dir Llama-3-8B-Instruct-Gradient-1048k-w8a8kv4-per-channel
huggingface-cli download mit-han-lab/Llama-3-8B-Instruct-Gradient-4194k-w8a8kv4-per-channel --local-dir Llama-3-8B-Instruct-Gradient-4194k-w8a8kv4-per-channel
快速开始 DuoAttention
我们提供了一个简单的单点击 patch,用于在 HuggingFace 模型上启用 DuoAttention 优化,包括 Llama 和 Mistral。attn_patterns
目录中提供了五个长上下文模型的预训练检索头模式:Llama-2-7B-32K-Instruct
、Llama-3-8B-Instruct-Gradient-1048k
、Llama-3-8B-Instruct-Gradient-4194k
、Mistral-7B-Instruct-v0.2
、Mistral-7B-Instruct-v0.3
和Meta-Llama-3.1-8B-Instruct
。如果您想训练自己的检索头模式,可以使用 scripts 目录中提供的训练脚本。以下是如何在Llama-3-8B-Instruct-Gradient-1048k
模型上启用 DuoAttention 的示例。
from duo_attn.utils import load_attn_pattern, sparsify_attention_heads
from duo_attn.patch import enable_duo_attention_eval
import transformers
import torch# 加载模型
model = transformers.AutoModelForCausalLM.from_pretrained("models/Llama-3-8B-Instruct-Gradient-1048k",torch_dtype=torch.bfloat16,low_cpu_mem_usage=True,attn_implementation="eager",
)# 加载注意力模式
attn_heads, sink_size, recent_size = load_attn_pattern("attn_patterns/Llama-3-8B-Instruct-Gradient-1048k/lr=0.02-reg=0.05-ctx=1000_32000-multi_passkey10"
)# 稀疏化注意力头
attn_heads, sparsity = sparsify_attention_heads(attn_heads, sparsity=0.5)# 启用DuoAttention
enable_duo_attention_eval(model,attn_heads,sink_size=64,recent_size=256,
)# 将模型移至GPU
model = model.cuda()# 准备进行推理!
演示
设置环境后,您可以运行以下脚本以在Llama-3-8B-Instruct-Gradient-4194k
模型上执行 W4A8KV4 与 DuoAttention 的演示。该演示旨在在单个 A100 GPU 上运行,并支持高达 330 万个 token 的上下文长度。
bash scripts/run_demo.sh
资源
- GitHub 仓库:https://github.com/mit-han-lab/duo-attention
- arXiv 技术论文:https://arxiv.org/pdf/2410.10819
❤️ 如果你也关注大模型与 AI 的发展现状,且对大模型应用开发非常感兴趣,我会快速跟你分享最新的感兴趣的 AI 应用和热点信息,也会不定期分享自己的想法和开源实例,欢迎关注我哦!
🥦 微信公众号|搜一搜:蚝油菜花 🥦