当前位置: 首页 > news >正文

DuoAttention:高效处理长上下文推理的 AI 框架,让 LLMs 如虎添翼!

❤️ 如果你也关注大模型与 AI 的发展现状,且对大模型应用开发非常感兴趣,我会快速跟你分享最新的感兴趣的 AI 应用和热点信息,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦


🚀 快速阅读

  1. DuoAttention 通过区分“检索头”和“流式头”两种注意力头,优化模型的内存使用和计算速度。
  2. DuoAttention 能在保持模型准确性的同时,减少内存消耗和提高解码及预填充的速度。
  3. 结合量化技术,DuoAttention 能在单个 GPU 上实现高达 330 万 token 的上下文推理。

正文(附运行示例)

DuoAttention 是什么

在这里插入图片描述

DuoAttention 是新型的框架,由 MIT 韩松团队提出,用在提高大型语言模型(LLMs)在处理长上下文时的推理效率。基于区分“检索头”和“流式头”两种注意力头,优化模型的内存使用和计算速度。检索头负责处理长距离依赖,需要完整的键值(KV)缓存,流式头关注最近 token 和注意力汇聚点,只需固定长度的 KV 缓存。两种注意力头让 DuoAttention 在保持模型准确性的同时,减少内存消耗和提高解码及预填充的速度。结合量化技术,DuoAttention 能在单个 GPU 上实现高达 330 万 token 的上下文推理,是处理长文本信息的有效方案。

DuoAttention 的主要功能

  • 提高长上下文推理效率:基于优化大型语言模型(LLMs)的注意力机制,DuoAttention 显著提升模型处理长上下文数据的能力。
  • 减少内存消耗:区分需要完整 KV 缓存的检索头和只需固定长度 KV 缓存的流式头,减少模型运行时的内存占用。
  • 加速解码和预填充过程:DuoAttention 优化模型的解码速度和预填充(Pre-filling)速度,提高 LLMs 的响应时间和处理效率至关重要。
  • 保持模型准确性:在减少内存消耗和提高效率的同时,DuoAttention 能保持模型在处理长短上下文任务时的准确性。

DuoAttention 的技术原理

  • 注意力头的区分:DuoAttention 将 LLMs 中的注意力头分为检索头和流式头。检索头负责捕捉上下文中的关键信息,对所有 token 进行完整注意力处理;流式头主要处理近期 token 和注意力汇聚点,不需要存储全部历史 KV 状态。
  • 检索头的 KV 缓存优化:为检索头保留完整的 KV 缓存,确保能捕捉到长距离依赖信息。
  • 流式头的轻量级 KV 缓存:流式头用固定长度的 KV 缓存,减少对内存的需求,支持模型高效处理长序列数据。
  • 检索头的自动识别:DuoAttention 用基于优化的算法和合成数据集训练模型,自动识别出哪些头是检索头,在推理时为分配适当的 KV 缓存策略。
  • 合成数据集:设计合成数据集和密码召回任务,DuoAttention 能确定哪些注意力头在保留或丢弃 KV 缓存后对模型输出有显著影响,优化模型的长上下文处理能力。

如何运行 DuoAttention

环境设置

训练和评估环境
conda create -yn duo python=3.10
conda activate duoconda install -y git
conda install -y nvidia/label/cuda-12.4.0::cuda-toolkit
conda install -y nvidia::cuda-cudart-dev
conda install -y pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidiapip install transformers accelerate sentencepiece datasets wandb accelerate sentencepiece datasets wandb zstandard matplotlib huggingface_hub
pip install tensor_parallelpip install ninja packaging
pip install flash-attn --no-build-isolation# LongBench评估
pip install seaborn rouge_score einops pandaspip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/# 安装DuoAttention
pip install -e .# 安装Block Sparse Streaming Attention
git clone git@github.com:mit-han-lab/Block-Sparse-Attention.git
cd Block-Sparse-Attention
python setup.py install
演示环境
conda create -yn duo_demo python=3.10
conda activate duo_demo# 安装DuoAttention
pip install -e .conda install -y git
conda install -y nvidia/label/cuda-12.4.0::cuda-toolkit
conda install -y nvidia::cuda-cudart-dev# 安装QServe
git clone git@github.com:mit-han-lab/qserve.git
cd qserve
pip install -e .
pip install ninja packaging
pip install flash-attn==2.4.1 --no-build-isolation
cd kernels
python setup.py install# 安装FlashInfer
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
pip install tensor_parallel

数据集

下载数据集:

mkdir -p datasets
cd datasetswget https://huggingface.co/datasets/togethercomputer/Long-Data-Collections/resolve/main/fine-tune/booksum.jsonl.zst

模型

下载 DuoAttention 支持的模型:

mkdir -p models
cd models# DuoAttention目前支持的评估模型
huggingface-cli download togethercomputer/Llama-2-7B-32K-Instruct --local-dir Llama-2-7B-32K-Instruct
huggingface-cli download gradientai/Llama-3-8B-Instruct-Gradient-1048k --local-dir Llama-3-8B-Instruct-Gradient-1048k
huggingface-cli download gradientai/Llama-3-8B-Instruct-Gradient-4194k --local-dir Llama-3-8B-Instruct-Gradient-4194k
huggingface-cli download mistralai/Mistral-7B-Instruct-v0.2 --local-dir Mistral-7B-Instruct-v0.2
huggingface-cli download mistralai/Mistral-7B-Instruct-v0.3 --local-dir Mistral-7B-Instruct-v0.3# 使用SmoothQuant和QServe的W8A8KV4模型
huggingface-cli download mit-han-lab/Llama-3-8B-Instruct-Gradient-1048k-w8a8kv4-per-channel --local-dir Llama-3-8B-Instruct-Gradient-1048k-w8a8kv4-per-channel
huggingface-cli download mit-han-lab/Llama-3-8B-Instruct-Gradient-4194k-w8a8kv4-per-channel --local-dir Llama-3-8B-Instruct-Gradient-4194k-w8a8kv4-per-channel

快速开始 DuoAttention

我们提供了一个简单的单点击 patch,用于在 HuggingFace 模型上启用 DuoAttention 优化,包括 Llama 和 Mistral。attn_patterns目录中提供了五个长上下文模型的预训练检索头模式:Llama-2-7B-32K-InstructLlama-3-8B-Instruct-Gradient-1048kLlama-3-8B-Instruct-Gradient-4194kMistral-7B-Instruct-v0.2Mistral-7B-Instruct-v0.3Meta-Llama-3.1-8B-Instruct。如果您想训练自己的检索头模式,可以使用 scripts 目录中提供的训练脚本。以下是如何在Llama-3-8B-Instruct-Gradient-1048k模型上启用 DuoAttention 的示例。

from duo_attn.utils import load_attn_pattern, sparsify_attention_heads
from duo_attn.patch import enable_duo_attention_eval
import transformers
import torch# 加载模型
model = transformers.AutoModelForCausalLM.from_pretrained("models/Llama-3-8B-Instruct-Gradient-1048k",torch_dtype=torch.bfloat16,low_cpu_mem_usage=True,attn_implementation="eager",
)# 加载注意力模式
attn_heads, sink_size, recent_size = load_attn_pattern("attn_patterns/Llama-3-8B-Instruct-Gradient-1048k/lr=0.02-reg=0.05-ctx=1000_32000-multi_passkey10"
)# 稀疏化注意力头
attn_heads, sparsity = sparsify_attention_heads(attn_heads, sparsity=0.5)# 启用DuoAttention
enable_duo_attention_eval(model,attn_heads,sink_size=64,recent_size=256,
)# 将模型移至GPU
model = model.cuda()# 准备进行推理!

演示

设置环境后,您可以运行以下脚本以在Llama-3-8B-Instruct-Gradient-4194k模型上执行 W4A8KV4 与 DuoAttention 的演示。该演示旨在在单个 A100 GPU 上运行,并支持高达 330 万个 token 的上下文长度。

bash scripts/run_demo.sh

资源

  • GitHub 仓库:https://github.com/mit-han-lab/duo-attention
  • arXiv 技术论文:https://arxiv.org/pdf/2410.10819

❤️ 如果你也关注大模型与 AI 的发展现状,且对大模型应用开发非常感兴趣,我会快速跟你分享最新的感兴趣的 AI 应用和热点信息,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦


http://www.mrgr.cn/news/58566.html

相关文章:

  • [计算机网络]第一周
  • 数据结构与算法——Java实现 47. 从中序与后序遍历序列构造二叉树
  • 优先算法——移动零(双指针)
  • Atlassian Team ‘24 Europe:推出Rovo、开发人员AI助手、新版Jira等多款AI创新,重塑团队协作
  • 【制造业&仓库】快递盒纸箱检测系统源码&数据集全套:改进yolo11-LAWDS
  • Linux的目录结构 常用基础命令(2)
  • vi编辑器
  • MySQL查看某个数据库里面每张表的字符集和字符排序集
  • 江协科技STM32学习- P21 ADC模数转换器
  • Isaac Sim Docker 部署并使用过程记录
  • 【数据结构和算法】二、python中的常用数据结构(数组、链表、堆栈、递归、二叉树、哈夫曼树等数据结构的基本原理讲解与实战演练)
  • 尼日利亚CRIA解析
  • c++实现boost搜索引擎功能扩展 介绍+代码(日志,处理暂停词,增加数据源,引入广告竞价,增加用户管理,连接mysql)
  • Nestjs请求处理顺序
  • 【信息系统管理工程师】与【信息系统项目管理师】傻傻分不清楚?一文说清楚
  • 谷歌开发者账号,为什么新号老是因为高风险被封?
  • 如何将原本打开Edge呈现出的360浏览器,更换成原本的Edge页面或者百度等其他页面
  • uniapp开发Web页面之动态菜单配置攻略
  • LEG引擎装备升级脚本,BLUE引擎传奇添加升级装备的NPC示例
  • 卷积神经网络评价指标
  • 客服的沟通技巧与策略
  • Sei 生态迎首个 MMORPG 游戏伙伴 Final Glory,开启新篇章
  • [Java进阶] 并发编程之进程、线程和协程
  • 23种设计模式
  • Vue3 + TypeScript 实现 iframe 嵌入与通信的完整指南以及全屏弹窗方案
  • 动态规划-子序列问题——376.摆动序列