当前位置: 首页 > news >正文

用infoNCE微调Embedding模型

infoNCE

代码1:(样本格式为query_n个positive_n个hardnegative)

  • PairwiseModel并不是模型,而是连接model和loss的一个包装类。
  • PairwiseModel接收两种类型样本 【query + pos pair】or【query + pos + neg triplet】。

  • CrossEntropyLoss还可以传入label_smoothing=0.05,用于对比学习。label_smoothing = 0.3时,label_smoothing 的作用是把硬标签 [0, 0, 1, 0] 平滑成类似 [0.1, 0.1, 0.7, 0.1],从而使得 CrossEntropyLoss 不再只惩罚预测不对的类,还会对非目标类的概率也做约束,使模型更加平滑稳定、泛化更强。
  • AutoModelForEmbedding的pooling_method选择mean还是cls根据模型来定,如果模型训练的时候用cls向量当做句子表征,则用cls。否则则用mean。

代码2:(样本格式为query_positive,只有正样本,负样本为batch内其他样本)

import os
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
model_name_or_path: str = '../model/m3e-base'
# model_name_or_path: str = '../model/bge-small-zh-v1.5'
batch_size: int = 2
epochs: int = 3
#数据集会按照dev、train、test划分。具体有哪个,得print来看,再用split="dev"获取 dev的部分。
train_dataset = load_dataset("../../dataset/C-MTEB/T2Reranking", split="dev") #这个数据集并不是 query_positive格式,而是query_n个positive,因此需要更改
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
train_model = PairwiseModel(model, loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
optimizer = AdamW(train_model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
training_arguments = TrainingArguments(output_dir='./checkpoints',num_train_epochs=epochs,per_device_train_batch_size=batch_size,remove_unused_columns=False,logging_steps=50,
)
# 处理后会得到一个两个key的dict,每个value是一个包含dict(其中包含input_ids、token_type_ids、attention_mask)
dc=RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[64, 128]) 
trainer = RetrievalTrainer(model=train_model,args=training_arguments,train_dataset=train_dataset,data_collator=dc, # 相当于 自定义collate_fn 函数
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()

参考:

动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习_m3e模型微调-CSDN博客

https://github.com/LongxingTan/open-retrievals?tab=readme-ov-file


http://www.mrgr.cn/news/98111.html

相关文章:

  • 十四种逻辑器件综合对比——《器件手册--逻辑器件》
  • qt pyqt5的开发, 修改psd图像
  • 七种数码管驱动/LED驱动综合对比——《器件手册--数码管驱动/LED驱动》
  • ubuntu 2204 安装 vcs 2018
  • (六)深入了解AVFoundation-播放:AirPlay、画中画后台播放
  • 【Flink运行时架构】核心组件
  • AI编程案例拆解|基于机器学习XX评分系统-前端篇
  • 汇编获取二进制
  • Linux基础14
  • 解决2080Ti使用节点ComfyUI-PuLID-Flux-Enhanced中遇到的问题
  • 2019年计算机真题
  • 小刚说C语言刷题——第22讲 二维数组
  • 【学习笔记】两个类之间的数据交互方式
  • 可配置多功能门芯片的12种用法推导——基于74LVC1G97芯片(附1G98、1G57、1G58、1G99用法)
  • 470用 Rand7() 实现 Rand10()
  • leetcode572 另一棵树的子树
  • 每天学一个 Linux 命令(14):cat
  • Linux进程概念
  • 【MQTT-协议原理】
  • 2025蓝桥杯算法竞赛深度突破:创新题型与高阶策略全解析