当前位置: 首页 > news >正文

LLAMAFACTORY:一键优化大型语言模型微调的利器

人工智能咨询培训老师叶梓 转载标明出处

模型适配到特定的下游任务,通常需要进行微调(fine-tuning),这一过程往往需要大量的计算资源。为了解决这一问题,来自北京航空航天大学和北京大学的研究人员共同开发了LLAMAFACTORY,这是一个统一的框架,集成了多种前沿的高效训练方法,使得用户可以灵活地自定义100多种大型语言模型的微调过程,而无需编写代码。表格1列出了LLAMAFACTORY框架中支持的高效微调技术。表格2展示了LLAMAFACTORY支持的数据集结构。

LLAMAFACTORY具有以下几个关键特性:

  1. 模型加载器(Model Loader):支持超过100种预训练模型,能够自动识别模型层次结构并附加适配器。
  2. 数据工作者(Data Worker):处理来自不同任务的数据,支持50多个数据集,通过设计良好的数据处理管道,将不同格式的数据集标准化为统一格式。
  3. 训练器(Trainer):集成了多种高效的微调方法,如LoRA+、GaLore等,支持分布式训练,进一步降低内存消耗。
  4. LLAMABOARD:提供了一个友好的可视化界面,用户可以通过Web界面配置和启动个别LLM微调过程,并实时监控训练状态。

想要掌握如何将大模型的力量发挥到极致吗?2024年10月26日(今晚8点)叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。

留言“参加”即可来叶老师的直播间互动,1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。互动交流,畅谈工作中遇到的实际问题。

LLAMAFACTORY框架

如图1所示,LLAMAFACTORY框架由三个主要模块组成:Model Loader、Data Worker和Trainer,以及一个用户友好的可视化界面LLAMABOARD。这些模块共同工作,实现了从数据加载到模型微调的全过程,而LLAMABOARD则提供了一个无需编码即可配置和启动大模型微调实例的界面,并能同步监控训练状态。

Model Loader负责操纵不同架构的模型以进行微调,支持包括视觉语言模型(VLMs)在内的大模型。这一模块包含四个主要部分:模型初始化、模型修补、模型量化和适配器附加。

  • 模型初始化:使用Transformers的Auto Classes来加载预训练模型并初始化参数。对于视觉语言模型,使用AutoModelForVision2Seq类进行加载,而其他模型则使用AutoModelForCausalLM类。若分词器的词汇量超过了嵌入层的容量,LLAMAFACTORY会调整层大小并用带噪声的均值初始化新参数。
  • 模型修补:为了启用S2注意力,使用了monkey patch来替换模型的前向计算。而对于flash attention,则使用了自Transformers 4.34.0起广泛支持的原生类。
  • 模型量化:可以通过bitsandbytes库将模型动态量化到8位或4位。对于4位量化,使用了双量化和4位浮点数作为QLoRA的方法。
  • 适配器附加:自动识别合适的层以附加适配器,以改善模型的收敛性。PEFT库提供了一种非常方便的方式来实现适配器方法,如LoRA、rsLoRA、DoRA和PiSSA。此外,为了执行基于人类反馈的强化学习(RLHF),在transformer模型的顶部附加了一个值头层,将每个token的表示映射为一个标量。

Data Worker开发了一个数据处理管道,包括数据集加载、对齐、合并和预处理,将不同任务的数据集标准化为统一格式。

  • 数据集加载:使用datasets库来加载数据,允许用户从Hugging Face Hub加载远程数据集或通过脚本或文件读取本地数据集。
  • 数据集对齐:设计了数据描述规范来表征数据集的结构,如alpaca数据集的三个列:instruction、input和output。根据数据描述规范将数据集转换为标准结构,以兼容各种任务。
  • 数据集合并:统一的数据集结构为合并多个数据集提供了高效的方法。在非流式模式下,数据集会在训练期间被简单连接然后打乱。在流式模式下,提供了交替读取不同数据集的方法。
  • 数据集预处理:LLAMAFACTORY旨在微调文本生成模型,主要用于聊天补全。聊天模板是这些模型中的关键组成部分,因为它与模型的指令跟踪能力高度相关。因此,提供了数十种聊天模板,可以根据模型类型自动选择。应用聊天模板后,使用分词器对句子进行编码。默认情况下,只计算补全部分的损失,而忽略提示。

Trainer集成了最先进的高效微调方法,如LoRA+、GaLore和BAdam,通过替换默认组件来实现。这些微调方法独立于Trainer,易于应用于各种任务。此外,LLAMAFACTORY还提出了模型共享RLHF方法,允许在不超过一个预训练模型的情况下进行整个RLHF训练。在分布式训练方面,可以与DeepSpeed结合使用,通过数据并行性充分利用计算设备的能力,并利用DeepSpeed ZeRO优化器进一步通过分区或卸载减少内存消耗。

LLAMABOARD是一个基于Gradio的统一用户界面,允许用户自定义大模型的微调而无需编写任何代码。它提供了简化的模型微调和推理服务,使用户能够轻松探索大模型在各自环境中的潜力。LLAMABOARD具有以下特点:

  • 轻松配置:允许用户通过与Web界面的交互来自定义微调参数,并为大多数参数提供默认值,简化了配置过程。
  • 可监控的训练:训练过程中,训练日志和损失曲线被可视化并实时更新,使用户能够监控训练进度。
  • 灵活评估:支持计算数据集上的文本相似度分数以自动评估模型,或通过与模型聊天进行人工评估。
  • 多语言支持:提供本地化文件,方便集成新语言以渲染界面。目前支持英语、俄语和中文,允许更广泛的用户使用LLAMABOARD进行大模型的微调。

LLAMAFACTORY通过其模块化设计和集成的高效微调技术,为大模型的微调提供了一个强大的平台,使得用户可以轻松地对大模型进行微调,以适应各种下游任务。此外,LLAMAFACTORY还提供了灵活的评估和多语言支持,使其成为一个真正全球化的工具,可以广泛应用于不同的语言和文化环境中。

实证研究

对LLAMAFACTORY的评估从两个角度进行:一是训练效率,包括内存使用、吞吐量和困惑度;二是对下游任务的适应能力。

训练效率

实验使用了PubMed数据集,该数据集包含超过3600万条生物医学文献记录。从文献摘要中提取了约40万个token来构建训练语料库。随后,使用生成式预训练目标和各种高效的微调方法对Gemma-2B、Llama2-7B和Llama2-13B模型进行了微调。比较了全参数微调、冻结微调、GaLore、LoRA和4位QLoRA的结果。微调后,计算了在训练语料库上的困惑度,以评估不同方法的效率。同时,也将预训练模型的困惑度作为基线纳入比较。

训练效率的结果在表4中展示,其中内存指的是训练期间消耗的峰值内存,吞吐量以每秒训练的token数计算,PPL代表模型在训练语料库上的困惑度。由于全参数微调Llama2-13B导致内存溢出,因此没有记录结果。可以观察到,由于预训练权重以较低精度表示,QLoRA始终具有最低的内存占用。利用Unsloth优化的LoRA层,LoRA展示了更高的吞吐量。对于大型模型,GaLore实现了更低的PPL,而LoRA在小型模型上表现更优。

下游任务的微调

为了评估不同高效微调方法的有效性,实验比较了各种模型在下游任务上微调后的性能。使用CNN/DM、XSum和AdGen三个代表性文本生成任务的2000个样本和1000个样本分别构建了不重叠的训练集和测试集。选择了几个指令微调模型,并使用不同的微调方法进行了序列到序列任务的微调。然后,比较了全参数微调(FT)、GaLore、LoRA和4位QLoRA的结果。微调后,计算了每个任务测试集上的ROUGE分数。同时,也将原始指令微调模型的分数作为基线纳入比较。

下游任务的评估结果在表5中展示。报告了ROUGE-1、ROUGE-2和ROUGEL的平均分数。由于GaLore方法可能不适用于Gemma-7B和Qwen2-7B模型,因此这些模型的一些结果没有包括在表中。一个有趣的发现是,LoRA和QLoRA在大多数情况下实现了最佳性能,除了ChatGLM3-6B和Llama2-7B模型在CNN/DM和AdGen数据集上的表现。这一现象突出了这些高效微调方法在使大模型适应特定任务方面的有效性。此外,观察到Llama3-8B在这些模型中实现了最佳性能,而Yi-6B和Mistral-7B在同等大小的模型中展现了有竞争力的性能。

表4和表5中的结果清楚地展示了LLAMAFACTORY框架在不同微调方法下的性能对比,证明了其在训练效率和下游任务适应性方面的有效性。这些实验不仅验证了LLAMAFACTORY的技术实力,也为未来大模型的微调提供了宝贵的参考数据。

论文链接:https://arxiv.org/abs/2403.13372
Github链接:https://github.com/hiyouga/LLaMA-Factory


http://www.mrgr.cn/news/59007.html

相关文章:

  • 中电信翼康工程师:我在 Apache SeaTunnel 社区的贡献之旅
  • 【WRF数据准备】基于GEE下载静态地理数据-LULC和ISA
  • Java--多态
  • 命名空间std, using namespace std
  • 微服务网关Zuul
  • 手机摄影入门
  • [旧日谈]高清画面撕裂问题考
  • 解决Redis缓存穿透(缓存空对象、布隆过滤器)
  • React中的hook
  • Bat 案例 -- 注册入站端口
  • PD诱骗取电快充协议,一款可额外定制功能的快充协议芯片
  • 119.WEB渗透测试-信息收集-ARL(10)
  • HT7181 16.8V,14A高效升压转换器
  • linux中myshell的实现
  • 长短期记忆网络(LSTM)详解
  • unity游戏开发之塔防游戏
  • 词云图大师支持词云图字体预览,轻松选择字体样式!
  • list 的实现
  • SQL语句的书写顺序与实际执行顺序的差异,以及如何利用执行顺序优化查询性能
  • SpringBoot中EasyExcel使用实践总结
  • 【Java】java 集合框架(详解)
  • 电脑连接海康相机并在PictureBox和HWindowControl中分别显示。
  • 开源数据库 - mysql - 组织结构(与oracle的区别)
  • 系统调用的介绍
  • 每日“亿“题 东方博宜OJ 1538 - 小 X 与煎饼达人(flip)
  • 线程安全介绍