Fine-tuning an Embedding Model with InfoNCE
InfoNCE
Code 1: (sample format is query + n positives + n hard negatives)
- PairwiseModel is not itself a model; it is a wrapper class that connects the model and the loss.
- PairwiseModel accepts two types of samples: a [query + pos] pair or a [query + pos + neg] triplet.
- CrossEntropyLoss can also take label_smoothing=0.05 for contrastive learning. For example, with label_smoothing = 0.4 over 4 classes, the hard label [0, 0, 1, 0] is smoothed to [0.1, 0.1, 0.7, 0.1], so CrossEntropyLoss no longer only penalizes the target class being predicted wrong, it also constrains the probabilities assigned to the non-target classes, which makes the model smoother, more stable, and better at generalizing (a minimal sketch follows this list).
- Whether AutoModelForEmbedding's pooling_method should be mean or cls depends on the model: if the model was trained to use the cls vector as the sentence representation, use cls; otherwise use mean.
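To make the in-batch InfoNCE objective and the effect of label_smoothing concrete, here is a minimal, self-contained PyTorch sketch. It only illustrates the idea and is not the open-retrievals implementation; the temperature value 0.05 is an assumption.

import torch
import torch.nn as nn
import torch.nn.functional as F

def info_nce(query_emb, pos_emb, temperature=0.05, label_smoothing=0.05):
    # query_emb, pos_emb: [batch, dim]; row i of pos_emb is the positive for query i,
    # and every other row in the batch acts as an in-batch negative.
    query_emb = F.normalize(query_emb, dim=-1)
    pos_emb = F.normalize(pos_emb, dim=-1)
    logits = query_emb @ pos_emb.T / temperature                  # [batch, batch] similarity matrix
    labels = torch.arange(logits.size(0), device=logits.device)   # positives sit on the diagonal
    # label_smoothing softens the one-hot target, e.g. [0, 0, 1, 0] -> [0.1, 0.1, 0.7, 0.1]
    # with smoothing 0.4 over 4 classes, so non-target probabilities are also constrained.
    return nn.CrossEntropyLoss(label_smoothing=label_smoothing)(logits, labels)

loss = info_nce(torch.randn(4, 768), torch.randn(4, 768))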
Code 2: (sample format is query_positive, i.e. positives only; the other samples in the batch serve as negatives)
import os
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
model_name_or_path: str = '../model/m3e-base'
# model_name_or_path: str = '../model/bge-small-zh-v1.5'
batch_size: int = 2
epochs: int = 3
# The dataset is divided into dev/train/test splits. Print the dataset to see which splits exist, then use split="dev" to take the dev portion.
train_dataset = load_dataset("../../dataset/C-MTEB/T2Reranking", split="dev")  # this dataset is not in query_positive format but query + n positives, so it must be converted
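# One way to do the conversion (a hedged sketch): the dev split stores one query with a
# list of positives (and negatives), so flatten it into (query, positive) pairs.
# The column names 'query'/'positive' are assumed from the C-MTEB reranking format;
# check print(train_dataset[0]) and adjust if they differ.
def to_pairs(batch):
    queries, positives = [], []
    for query, pos_list in zip(batch['query'], batch['positive']):
        for pos in pos_list:
            queries.append(query)
            positives.append(pos)
    return {'query': queries, 'positive': positives}

train_dataset = train_dataset.map(to_pairs, batched=True, remove_columns=train_dataset.column_names)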
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
train_model = PairwiseModel(model, loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
optimizer = AdamW(train_model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.05 * num_train_steps), num_training_steps=num_train_steps)
training_arguments = TrainingArguments(
    output_dir='./checkpoints',
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    remove_unused_columns=False,
    logging_steps=50,
)
# After collation we get a dict with two keys ('query' and 'positive'); each value is itself a dict containing input_ids, token_type_ids and attention_mask.
dc = RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[64, 128])
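# Optional sanity check, assuming the collator can be called like a standard Hugging Face
# data collator on a list of samples; expect the two-key structure described above:
# batch = dc([train_dataset[0], train_dataset[1]])
# print({k: {kk: vv.shape for kk, vv in v.items()} for k, v in batch.items()})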
trainer = RetrievalTrainer(
    model=train_model,
    args=training_arguments,
    train_dataset=train_dataset,
    data_collator=dc,  # acts as a custom collate_fn
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()
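After training, it is worth sanity-checking the embeddings by encoding a query and a few candidate passages and comparing their similarities. The sketch below uses plain transformers with mean pooling to match pooling_method="mean" above; the checkpoint path is an assumption and should point to wherever the fine-tuned encoder weights were saved in Hugging Face format.

import torch
from transformers import AutoModel, AutoTokenizer

ckpt = './checkpoints'  # assumed path; point this at the saved fine-tuned encoder
tok = AutoTokenizer.from_pretrained(ckpt)
encoder = AutoModel.from_pretrained(ckpt)
encoder.eval()

def embed(texts):
    # Mean pooling over non-padding tokens, matching pooling_method="mean" above
    inputs = tok(texts, padding=True, truncation=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state       # [batch, seq, dim]
    mask = inputs['attention_mask'].unsqueeze(-1).float()   # [batch, seq, 1]
    emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)      # [batch, dim]
    return torch.nn.functional.normalize(emb, dim=-1)

query_emb = embed(['如何申请退款'])
doc_emb = embed(['退款流程说明', '今天天气不错'])
print(query_emb @ doc_emb.T)  # the related passage should score higher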
References:
动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习 (CSDN blog)
https://github.com/LongxingTan/open-retrievals?tab=readme-ov-file