当前位置: 首页 > news >正文

【Conan-embedding模型排名第一的embedding中文模型】

在这里插入图片描述

文章目录

  • 1、Conan-embedding模型
  • 2、什么是动态难负样本挖掘

1、Conan-embedding模型

Conan-embedding模型,该模型通过最大化利用更多、更高质量的负样本来增强嵌入模型的能力,在大规模文本嵌入基准(MTEB)的中文排行榜上排名第一。

  1. 研究背景
    • 随着RAG的流行,嵌入模型的能力受到关注,其主要通过对比学习训练,负样本是关键,但现有硬负挖掘策略多作为预处理步骤,存在局限。
  2. Conan-Embedding模型方法
    • 训练工作流程
      • 预训练:采用多阶段训练,预训练阶段使用标准数据过滤方法,用bge - large - zh - v1.5模型评分并丢弃低分数据,利用InfoNCE loss with In - Batch Negative进行训练。
      • 监督微调:分为检索和STS任务,检索任务用InfoNCE loss,STS任务采用CoSENT loss。
    • 动态难负样本挖掘:训练中迭代挖掘硬负样本,根据硬负样本相对查询的平均得分判断是否重新挖掘。
    • 跨GPU批次平衡损失(CBB):平衡不同任务负样本数量,在每个训练周期以平衡方式引入任务,计算CBB Loss。
  3. 实验结果
    • 实现细节:基于BERT large模型,采用线性层扩展维度,运用MRL技术,设置不同阶段的优化器、学习率等参数,使用多种GPU进行训练。
    • 数据集:预训练收集多种文本数据对及LLM生成的数据,微调选择检索、分类和STS数据集。
    • CMTEB结果:在CMTEB基准测试中超越现有模型。
    • 消融研究:动态硬负挖掘和CBB Loss显著优于直接微调的方法。
    • 分析:CBB Loss使损失平滑下降,优于单独训练。
  4. 结论
    • 提出Conan-embedding模型,通过动态难负挖掘和跨GPU平衡损失提升嵌入模型性能,模型已上传至Huggingface。

2、什么是动态难负样本挖掘

详细解释动态难负样本挖掘的工作流程。

import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tupleclass DynamicNegativeMiner:def __init__(self, n_samples: int = 100, update_interval: int = 100):self.n_samples = n_samplesself.update_interval = update_intervalself.initial_scores = {}self.current_scores = {}self.replacement_count = 0self.score_history = []def should_replace(self, neg_idx: int) -> bool:"""判断是否需要替换负样本"""current = self.current_scores[neg_idx]initial = self.initial_scores[neg_idx]return (abs(current) < 0.8) and (current * 1.15 < initial)def get_replacement_range(self) -> Tuple[int, int]:"""获取替换样本的索引范围"""i = self.replacement_count + 1start_idx = (i - 1) * self.n_samples + 10end_idx = i * self.n_samples + 10return start_idx, end_idxdef simulate_training():"""模拟训练过程"""# 初始化miner = DynamicNegativeMiner(n_samples=5, update_interval=3)iterations = 15# 初始负样本negative_indices = list(range(5))# 记录训练过程score_history = []replacement_history = []# 模拟训练迭代for iter_idx in range(iterations):# 模拟模型逐渐学习负样本的过程# 随着训练进行,得分会降低(模型越来越确信这些是负样本)base_score = 0.9 * np.exp(-0.2 * iter_idx)scores = base_score + np.random.normal(0, 0.1, len(negative_indices))# 首次记录初始得分for idx, score in zip(negative_indices, scores):if idx not in miner.initial_scores:miner.initial_scores[idx] = scoreminer.current_scores[idx] = scorescore_history.append(scores.mean())# 每隔update_interval次迭代检查是否需要替换if (iter_idx + 1) % miner.update_interval == 0:replacements_needed = []for i, neg_idx in enumerate(negative_indices):if miner.should_replace(neg_idx):replacements_needed.append(i)# 如果需要替换if replacements_needed:miner.replacement_count += 1start_idx, end_idx = miner.get_replacement_range()new_samples = list(range(start_idx, start_idx + len(replacements_needed)))for i, pos in enumerate(replacements_needed):negative_indices[pos] = new_samples[i]replacement_history.append(iter_idx)return score_history, replacement_history# 运行模拟并可视化结果
score_history, replacement_history = simulate_training()plt.figure(figsize=(10, 6))
plt.plot(score_history, 'b-', label='平均得分')
plt.scatter(replacement_history, [score_history[i] for i in replacement_history], color='red', marker='x', s=100, label='负样本替换点')
plt.axhline(y=0.8, color='g', linestyle='--', label='阈值 (0.8)')
plt.grid(True)
plt.xlabel('训练迭代次数')
plt.ylabel('负样本平均得分')
plt.title('动态硬负挖掘过程演示')
plt.legend()# 在替换点添加标注
for idx in replacement_history:plt.annotate(f'第{idx//3}次替换', xy=(idx, score_history[idx]),xytext=(10, 10), textcoords='offset points')

让我解释这个示例的关键部分:

  1. 负样本评估机制:
  • 记录每个负样本的初始得分和当前得分
  • 使用两个条件判断是否需要替换:
    abs(current) < 0.8  # 得分较低,说明模型已经较好地学习了
    current * 1.15 < initial  # 相比初始状态有显著改善
    
  1. 替换策略:
  • 每3次迭代检查一次负样本状态
  • 计算替换样本的范围:
    start_idx = (i - 1) * n_samples + 10
    end_idx = i * n_samples + 10
    
  • 对需要替换的样本进行批量更新
  1. 可视化显示:
  • 蓝线:负样本的平均得分变化
  • 红色×:发生替换的时间点
  • 绿色虚线:得分阈值(0.8)

从图中可以看到:

  1. 随着训练进行,负样本的得分逐渐降低(模型学习效果提升)
  2. 当得分降到一定程度时,触发替换机制
  3. 替换后得分会暂时升高(新的挑战性负样本)
  4. 这个循环持续进行,保持训练的难度

这种方法的优势:

  1. 自适应:根据模型的学习进度自动调整负样本
  2. 持续挑战:不断提供新的困难样本
  3. 效率高:避免在简单负样本上浪费计算资源

http://www.mrgr.cn/news/63357.html

相关文章:

  • 信息论与熵information and entropy
  • Rust 力扣 - 1456. 定长子串中元音的最大数目
  • 刘艳兵-DBA016-在您的数据库中,SALES表存在于SH用户中,并且启用了统一审计。作为DBA,您成功执行了以下指令:
  • 【neo4j】 neo4j cypher单一语句 optional 可选操作的技巧
  • Unity BesHttp插件修改Error log的格式
  • 【微服务】Nacos 注册中心
  • 2024 Rust现代实用教程 流程控制与函数
  • 递归到分治
  • 显示器不亮?解决“显示器不支持当前的输入时序,请将输入时序更改为 1920x1080, 60Hz”的终极指南
  • 别再盲目选购随身WiFi了!一文教你精准挑选最适合自己的随身WiFi!随身wifi哪个牌子的最好用?
  • o1驾驶无人机后空翻,OpenAI开发者日惊掉下巴!2分钟爆改代码写App
  • Vite学习之模式
  • AI实践-PyTorch-CNN-手写数字识别
  • 多线程在打包工具中的运用
  • 5分钟搞定:Spring AI支持SpringBoot快速构建人工智能AI应用_springai_springboot_AI应用
  • jlink识别不到gd32@
  • 连续11年领跑行业 凯迪仕智能锁双11再次稳居全渠道销量第一
  • 鸿蒙HarmonyOS应用开发者(基础+高级)认证
  • jmeter结合ansible分布式压测--准备工作
  • 青少年编程与数学 02-003 Go语言网络编程 03课题、网络编程协议
  • 通义灵码上新功能:用代码画流程图
  • 仓库物品带下拉提示搜索与开单自定义数量和备注带提交反馈单页功能
  • 充电宝哪一款最实用?2024年推荐五款性价比最高选择,附选购攻略
  • R语言贝叶斯
  • LeetCode 热题 100之二叉树
  • 语音IC方案,在交通信号灯语音提示器的应用解析,NV040D