NLP: SBERT介绍及sentence-transformers库的使用
1. Sentence-BERT
Sentence-BERT(简写SBERT)模型是BERT模型最有趣的变体之一,通过扩展预训练的BERT模型来获得固定长度的句子特征,主要用于句子对分类、计算两个句子之间的相似度任务。
1.1 计算句子特征
SBERT模型同样是将句子标记送入预训练的BERT模型来获取句子特征的,但这里并不使用 R [ C L S ] R_{[CLS]} R[CLS]作为最终的句子特征。在SBERT中,通过汇聚所有标记的特征来计算整个句子的特征。具体的汇聚方法有两种:平均汇聚和最大汇聚。
- 平均汇聚:使用平均汇聚来获取句子特征。这种方法得到的句子的特征将包含所有词语(Token)的意义。
- 最大汇聚:使用最大汇聚来获取句子特征。这种方法得到的句子的特征将仅包含重要词语(Token)的意义。
1.2 SBERT架构
SBERT模型使用二元组网络架构来执行以一对句子作为输入的任务,并使用三元组网络架构来实现三元组损失函数。
1.2.1 使用二元组网络架构的SBERT模型
SBERT通过二元组网络(两个共享同样权重的相同网络)架构对执行句子对任务的预训练的BERT模型进行微调。句子对任务具体包括以下两种:
- 句子对分类任务: 判断句子对是否相似。相似则返回1,不相似则返回0。其SBERT模型架构为:
- 句子对回归任务:计算两个给定句子之间的语义相似度。其对应的SBERT架构为:
1.2.2 使用三元组网络架构的SBERT模型
三元组网络架构的SBERT模型的任务计算出一个特征,使锚定句和正向句之间的相似度高,锚定句和负向句之间的相似度低。其架构如下:
2. 计算文本相似度
2.1 bi-encoder VS cross-encoder
bi-encoder和cross-encoder是语义匹配、文本相似度、信息检索场景下下常用的两种模型架构。这两者都基于深度学习模型(如BERT等)进行编码和比较文本之间的相似度,但它们在计算方式、效率和适用场景上有显著的区别。
2.1.1 bi-encoder
bi-encoder是一种独立编码方式,即输入的两个文本会被分别编码为独立的向量,然后通过计算这两个向量的相似度来判断文本之间的关系。使用bi-encoder方式计算文本相似度的案例如下:
from sentence_transformers import SentenceTransformer
#加载预训练的sentence transformer模型
model = SentenceTransformer('all-MiniLM-L6-v2')
sentences=["这个商品挺好用的","这个商品一点也不好用"]
embeddings=model.encode(sentences)
similarity=model.similarity(embeddings[0],embeddings[1])
print(similarity) #0.5868
2.1.2 cross-encoder
cross-encoder是一种联合编码方式,即将两个文本拼接在一起作为模型的输入,模型会通过对两个文本的联合表示来直接输出一个相似度分数。这种方式可以更好地捕捉两个文本之间的复杂交互信息,因此在诸如问答匹配、精确文本相似度计算等需要细粒度判断的任务上表现更好。具体使用方式如下:
from sentence_transformers.cross_encoder import CrossEncoder
model=CrossEncoder("cross-encoder/stsb-distilroberta-base")
query="这个产品挺好用的"
corpus=["这个产品很好","这个产品的设计有很大问题","这个产品不好用"]
ranks=model.rank(query,corpus)
for rank in ranks:print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}")
3 微调SBERT
接下来我们使用STSB数据集对SBERT模型进行微调。具体代码如下
from datasets import load_dataset
from sentence_transformers import losses
from sentence_transformers import (SentenceTransformer,SentenceTransformerTrainingArguments,SentenceTransformerTrainer,
)
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator,SimilarityFunction
from datasets import load_datasettrain = load_dataset("sentence-transformers/stsb",split='train')
dev = load_dataset("sentence-transformers/stsb",split='validation')
test= load_dataset("sentence-transformers/stsb",split='test')model=SentenceTransformer('FacebookAI/xlm-roberta-base')loss=losses.CoSENTLoss(model=model)args=SentenceTransformerTrainingArguments(output_dir='models/model1',num_train_epochs=1,per_device_train_batch_size=16,per_device_eval_batch_size=16,warmup_ratio=0.1,eval_strategy='steps',eval_steps=100,save_strategy='steps',save_total_limit=2,bf16=False,)dev_evaluator=EmbeddingSimilarityEvaluator(sentences1=dev['sentence1'],sentences2=dev['sentence2'],scores=dev['score'],main_similarity=SimilarityFunction.COSINE,name='dev-evaluator')dev_evaluator(model)trainer=SentenceTransformerTrainer(model=model,args=args,train_dataset=train,eval_dataset=dev,loss=loss,evaluator=dev_evaluator)
trainer.train() test_evaluator=EmbeddingSimilarityEvaluator(sentences1=test['sentence1'],sentences2=test['sentence2'],scores=test['score'],main_similarity=SimilarityFunction.COSINE,name='test-evaluator')
test_evaluator(model)
model.save_pretrained('models/model1')
关于上述代码,需要说明以下几点:
- 训练和评估SBERT的数据类型必须是
datasets.Dataset
和datasets.DatasetDict
。 - 数据集的格式必须和损失函数、评估器相匹配。如果损失函数需要标签字段,那么数据集必须有“label”或“score”字段;其他名称非“label”或“score”的字段将自动归属于Inputs字段。所以在进行后续步骤时,必须将数据集中的无法标签删除,同时要保证数据集中的字段顺序与对应损失函数中要求的顺序一致。
- 需要根据具体的任务以及数据集的形式选择合适的损失函数,没有哪种损失函数可以解决所有的问题。SBERT提供的损失函数列表如下:
https://www.sbert.net/docs/sentence_transformer/loss_overview.html - 微调后的模型可以和其他预训练的模型一样使用,比如计算文本相似度,这里不再赘述。
参考资料
- BERT基础教程: Transformer大模型实战
- https://baijiahao.baidu.com/s?id=1801193891938395467
- https://www.sbert.net