【LLM论文日更】| GRIT如何统一文本生成与嵌入
- 论文:https://arxiv.org/pdf/2404.05961
- 代码:https://github.com/McGill-NLP/llm2vec
- 机构:McGill University, Mila ServiceNow Research ,Facebook CIFAR AI Chair
- 领域:embedding model
- 发表:COLM 2024
研究背景
所有基于文本的语言问题都可以简化为生成或嵌入。当前的模型仅在其中之一上表现良,本文引入了生成表征指令调整(GRIT),通过指令来训练大型语言模型来处理生成任务和嵌入任务。与其他开放模型相比,本文生成的 GRITLM 7B 在大规模文本嵌入基准 (MTEB) 上树立了新的技术水平,并且在一系列生成任务上优于同等规模的所有模型。通过进一步扩展,GRITLM 8X7B 的性能优于本文尝试过的所有开放生成语言模型,同时仍然是最好的嵌入模型之一。值得注意的是,我们发现 GRIT 仅匹配生成数据或嵌入数据的训练,因此我们可以在不损失性能的情况下统一两者。而且通过 GRIT 进行统一,不再需要单独的检索和生成模型,可以将长文档的检索增强生成 (RAG) 速度提高 60% 以上。
研究方法
这篇论文提出了生成表示指令调优(GRIT)方法,用于解决生成任务和嵌入任务的统一问题。具体来说,
-
生成表示指令调优(GRIT):GRIT通过指令区分生成任务和嵌入任务。对于嵌入任务,指令包含目标域、意图和单位,表示为一个数值张量;对于生成任务,指令生成文本输出。
损失函数:对于嵌入数据,使用对比损失函数:
其中,f是GRITLM参数化的模型,τ 是温度超参数,σ 对每个输出进行池化后应用余弦相似度,q和 d 分别是查询和文档样本。
3. 生成数据损失:对于生成数据,使用语言建模损失函数:
其中,f 是GRITLM模型,η 是语言建模头,x 是生成训练样本。
4. 总损失函数:将生成和嵌入损失函数通过可选的损失权重λRep和λGen进行聚合。
实验设计
- 数据集:实验使用了Mistral 7B和Mixtral 8x7B模型,并从E5和Tilu 2数据集中进行微调。E5数据集通过添加S2ORC数据集进行扩展,Tilu 2数据集过滤掉了包含模型答案的自定义提示。
- 批大小和训练步数:对于GRITLM 7B,嵌入数据的批大小为2048,生成数据的批大小为256,总共训练1253步。对于GRITLM 8x7B,由于计算限制,嵌入数据的批大小减少到256。
- 超参数:学习率为2e-5,使用3%的步骤进行线性预热,然后线性衰减到0。使用PyTorch FSDP、梯度检查点、BF16混合精度训练等策略来节省内存。
结果与分析
嵌入模型性能:
生成性能:
RAG联动
不足与反思
- 计算成本:由于使用两个目标函数进行训练,GRIT需要更多的计算资源。然而,由于微调比预训练便宜得多,作者认为这些成本远远被其带来的好处所抵消。
- 多语言支持:尽管GRITLM在非英语任务上也有不错的表现,但主要的性能提升可能来自于数据集和架构的变化。
- 多模态任务:许多嵌入和生成问题不仅仅是基于文本的,如图像和文本的联合嵌入、生成图像字幕等。探索这些任务是否可以像文本嵌入和生成一样容易地统一仍需进一步研究。
- 优化RAG:未来的工作可以考虑优化GRITLM以更好地与检索系统集成,例如通过让生成模型在其认为必要时自行发起搜索。
- 预训练:实验使用了现成的预训练语言模型,但也可以考虑使用GRIT方法从头开始进行预训练。
- 格式效率:当前的格式效率不高,未来可以通过使用特殊标记来简化编码过程,从而降低训练和推理的成本。
- 打包策略:未来工作可以探索在训练期间对嵌入样本进行打包,甚至将生成和嵌入训练数据打包到同一个样本中,以提高效率。