当前位置: 首页 > news >正文

NLP: SBERT介绍及sentence-transformers库的使用

1. Sentence-BERT

  Sentence-BERT(简写SBERT)模型是BERT模型最有趣的变体之一,通过扩展预训练的BERT模型来获得固定长度的句子特征,主要用于句子对分类、计算两个句子之间的相似度任务。

1.1 计算句子特征

  SBERT模型同样是将句子标记送入预训练的BERT模型来获取句子特征的,但这里并不使用 R [ C L S ] R_{[CLS]} R[CLS]作为最终的句子特征。在SBERT中,通过汇聚所有标记的特征来计算整个句子的特征。具体的汇聚方法有两种:平均汇聚和最大汇聚。

  • 平均汇聚:使用平均汇聚来获取句子特征。这种方法得到的句子的特征将包含所有词语(Token)的意义。
  • 最大汇聚:使用最大汇聚来获取句子特征。这种方法得到的句子的特征将仅包含重要词语(Token)的意义。
    在这里插入图片描述

1.2 SBERT架构

  SBERT模型使用二元组网络架构来执行以一对句子作为输入的任务,并使用三元组网络架构来实现三元组损失函数。

1.2.1 使用二元组网络架构的SBERT模型

  SBERT通过二元组网络(两个共享同样权重的相同网络)架构对执行句子对任务的预训练的BERT模型进行微调。句子对任务具体包括以下两种:

  • 句子对分类任务: 判断句子对是否相似。相似则返回1,不相似则返回0。其SBERT模型架构为:
    在这里插入图片描述
  • 句子对回归任务:计算两个给定句子之间的语义相似度。其对应的SBERT架构为:在这里插入图片描述
1.2.2 使用三元组网络架构的SBERT模型

  三元组网络架构的SBERT模型的任务计算出一个特征,使锚定句和正向句之间的相似度高,锚定句和负向句之间的相似度低。其架构如下:
在这里插入图片描述

2. 计算文本相似度

2.1 bi-encoder VS cross-encoder

  bi-encoder和cross-encoder是语义匹配、文本相似度、信息检索场景下下常用的两种模型架构。这两者都基于深度学习模型(如BERT等)进行编码和比较文本之间的相似度,但它们在计算方式、效率和适用场景上有显著的区别。

2.1.1 bi-encoder

  bi-encoder是一种独立编码方式,即输入的两个文本会被分别编码为独立的向量,然后通过计算这两个向量的相似度来判断文本之间的关系。使用bi-encoder方式计算文本相似度的案例如下:

from sentence_transformers import SentenceTransformer
#加载预训练的sentence transformer模型
model = SentenceTransformer('all-MiniLM-L6-v2')
sentences=["这个商品挺好用的","这个商品一点也不好用"]
embeddings=model.encode(sentences)
similarity=model.similarity(embeddings[0],embeddings[1])
print(similarity) #0.5868
2.1.2 cross-encoder

  cross-encoder是一种联合编码方式,即将两个文本拼接在一起作为模型的输入,模型会通过对两个文本的联合表示来直接输出一个相似度分数。这种方式可以更好地捕捉两个文本之间的复杂交互信息,因此在诸如问答匹配、精确文本相似度计算等需要细粒度判断的任务上表现更好。具体使用方式如下:

from sentence_transformers.cross_encoder import CrossEncoder
model=CrossEncoder("cross-encoder/stsb-distilroberta-base")
query="这个产品挺好用的"
corpus=["这个产品很好","这个产品的设计有很大问题","这个产品不好用"]
ranks=model.rank(query,corpus)
for rank in ranks:print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}")

3 微调SBERT

  接下来我们使用STSB数据集对SBERT模型进行微调。具体代码如下

from datasets import load_dataset
from sentence_transformers import losses
from sentence_transformers import (SentenceTransformer,SentenceTransformerTrainingArguments,SentenceTransformerTrainer,
)
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator,SimilarityFunction
from datasets import load_datasettrain = load_dataset("sentence-transformers/stsb",split='train')
dev = load_dataset("sentence-transformers/stsb",split='validation')
test= load_dataset("sentence-transformers/stsb",split='test')model=SentenceTransformer('FacebookAI/xlm-roberta-base')loss=losses.CoSENTLoss(model=model)args=SentenceTransformerTrainingArguments(output_dir='models/model1',num_train_epochs=1,per_device_train_batch_size=16,per_device_eval_batch_size=16,warmup_ratio=0.1,eval_strategy='steps',eval_steps=100,save_strategy='steps',save_total_limit=2,bf16=False,)dev_evaluator=EmbeddingSimilarityEvaluator(sentences1=dev['sentence1'],sentences2=dev['sentence2'],scores=dev['score'],main_similarity=SimilarityFunction.COSINE,name='dev-evaluator')dev_evaluator(model)trainer=SentenceTransformerTrainer(model=model,args=args,train_dataset=train,eval_dataset=dev,loss=loss,evaluator=dev_evaluator)   
trainer.train()                        test_evaluator=EmbeddingSimilarityEvaluator(sentences1=test['sentence1'],sentences2=test['sentence2'],scores=test['score'],main_similarity=SimilarityFunction.COSINE,name='test-evaluator')
test_evaluator(model)
model.save_pretrained('models/model1')

关于上述代码,需要说明以下几点:

  • 训练和评估SBERT的数据类型必须是datasets.Datasetdatasets.DatasetDict
  • 数据集的格式必须和损失函数、评估器相匹配。如果损失函数需要标签字段,那么数据集必须有“label”或“score”字段;其他名称非“label”或“score”的字段将自动归属于Inputs字段。所以在进行后续步骤时,必须将数据集中的无法标签删除,同时要保证数据集中的字段顺序与对应损失函数中要求的顺序一致。
  • 需要根据具体的任务以及数据集的形式选择合适的损失函数,没有哪种损失函数可以解决所有的问题。SBERT提供的损失函数列表如下:
    https://www.sbert.net/docs/sentence_transformer/loss_overview.html
  • 微调后的模型可以和其他预训练的模型一样使用,比如计算文本相似度,这里不再赘述。

参考资料

  1. BERT基础教程: Transformer大模型实战
  2. https://baijiahao.baidu.com/s?id=1801193891938395467
  3. https://www.sbert.net

http://www.mrgr.cn/news/46490.html

相关文章:

  • Qt QML专栏目录结构
  • wireshark上没有显示出来rtp协议如何处理
  • 【docker踩坑记录】
  • ctf竞赛
  • 网络协议(八):IP 协议
  • 【WPS】【WORDEXCEL】【VB】实现微软WORD自动更正的效果
  • 基于SpringBoot的校园新闻管理系统 计算机毕业设计选题 Java毕业设计 SpringBoot+Vue 前后端分离 [附源码+安装调试]
  • MAX模型转为las点云模型
  • 响应速度相关知识
  • 汽车胶黏剂市场研究:预计2030年全球市场规模将达到67.4亿美元
  • Apache Flink 配合 Debezium 连接器来捕获 Oracle 数据库变更日志的应用
  • 图像平滑处理
  • 基于vue框架的大学生在线教育jp6jw(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。
  • IDEA 输入英文字体变了的问题
  • 【宽搜】6. leetcode 513 找树左下角的值
  • patch函数前两个参数位
  • c++输出保留n位小数
  • 默认情况下,`QTableView`中的单元格内容是不支持自动换行的,而是将文本截断或者显示省略号。要实现内容自动换行。要用Delegate
  • 鹧鸪云光伏软件全面解析
  • Web3与人工智能的交叉应用探索
  • 【深度学习总结】热力图-Grad-CAM使用
  • whistle使用实践
  • Linux内核 -- 使用 `proc_create_seq` 和 `seq_operations` 快速创建 /proc 文件
  • VAE(与GAN)
  • k8s pod详解使用
  • 【系统架构设计师】案例专题二:系统开发基础考点梳理