当前位置: 首页 > news >正文

BERT解析

BERT项目
我在BERT添加注释和部分推理代码
main.py

vocab = WordVocab.load_vocab(args.vocab_path)#加载vocab

请添加图片描述
那么这个加载的二进制是什么呢?

1. 加载数据集

继承关系:TorchVocab --> Vocab --> WordVocab

  • TorchVocab

该类主要是定义了一个词典对象,包含如下三个属性:

freqs:是一个collections.Counter对象,能够记录数据集中不同token所出现的次数

stoi:是一个collections.defaultdict对象,将数据集中的token映射到不同的index

itos:是一个列表,保存了从index到token的映射信息

  • Vocab

Vocab继承TorchVocab,该类主要定义了一些特殊token的表示

这里用到了一个装饰器,简单地说:装饰器就是修改其他函数的功能的函数。这里包含了一个序列化的操作

  • WordVocab

WordVocab继承自Vocab,里面包含了两个方法to_seqfrom_seq分别是将token转换成index和将index转换成token表示

2.datasets.py

BERTDataset这个类有——个方法

  1. init():初始化BERTDataset类。
  2. len():返回数据集的大小,即语料库的行数。
  3. getitem():根据索引获取数据项,包括BERT输入序列、标签、段标签和下一句预测标签。
  4. random_word():对给定句子中的单词进行随机处理,用于生成BERT的输入和标签。
  5. random_sent():随机决定是否交换下一句,用于训练BERT的下一句预测任务。
  6. get_corpus_line():根据索引获取语料库中的句子对。
  7. get_random_line():获取语料库中的一个随机句子。
  • init()
    def __init__(self,corpus_path,vocab,seq_len,encoding="utf-8",corpus_lines=None,on_memory=False,):"""初始化BERTDataset类。     如果on_memory为True,则将语料库加载到内存中。否则,计算语料库的行数。参数:corpus_path (str): 语料库文件的路径。vocab (Vocab): 词汇表对象,用于将单词映射到索引。seq_len (int): 序列长度,BERT输入序列的最大长度。encoding (str, optional): 文件编码,默认为'utf-8'。corpus_lines (int, optional): 语料库行数,如果提供,则在初始化时不会计算。on_memory (bool, optional): 是否将整个语料库加载到内存中,默认为True。"""self.vocab = vocabself.seq_len = seq_lenself.on_memory = on_memoryself.corpus_lines = corpus_linesself.corpus_path = corpus_pathself.encoding = encodingwith open(corpus_path, "r", encoding=encoding) as f:if self.corpus_lines is None and not on_memory:for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):self.corpus_lines += 1if on_memory:self.lines = [line[:-1].split("\t")for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]print('第一行的内容',self.lines[0])print("第一行的数量:",len(self.lines[0]))self.corpus_lines = len(self.lines)"""第一行的内容 ['Cuba to Get Rid of Dollars After a Decade', " HAVANA (Reuters) Cubans rushed to change dollars into  local pesos on Tuesday as President Fidel Castro's communist  government prepared to pull the U.S. currency from circulation  more than a decade after it was legalized here."]第一行的数量: 2"""if not on_memory:self.file = open(corpus_path, "r", encoding=encoding)self.random_file = open(corpus_path, "r", encoding=encoding)for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):self.random_file.__next__()

corpus.small中前三行的内容:

请添加图片描述

所以__init__方法就是将句子分为上下两句,存储在列表中

  • len()
    def __len__(self):"""返回数据集的大小,即语料库的行数。返回:int: 语料库的行数。"""return self.corpus_lines

介绍getitem()之前,我们要先看**random_sent()random_word()**这两个函数

    def __getitem__(self, item):"""根据索引获取数据项,包括BERT输入序列、标签、段标签和下一句预测标签。参数:item (int): 数据项的索引。返回:dict: 包含BERT输入序列、标签、段标签和下一句预测标签的字典。"""t1, t2, is_next_label = self.random_sent(item)t1_random, t1_label = self.random_word(t1)t2_random, t2_label = self.random_word(t2)
  • random_sent()

在这个函数里面我们先看get_corpus_line()函数和get_random_line()

  • get_corpus_line()

t1和t2是表示上半个句子和下半个句子,用于训练BERT的下一句预测任务。

    def get_corpus_line(self, item):"""根据索引获取语料库中的句子对。参数:item (int): 数据项的索引。返回:tuple: 一个句子对。"""if self.on_memory:#返回一句话中的上半和下半return self.lines[item][0], self.lines[item][1]else:line = self.file.__next__()if line is None:self.file.close()self.file = open(self.corpus_path, "r", encoding=self.encoding)line = self.file.__next__()t1, t2 = line[:-1].split("\t")return t1, t2
  • get_random_line()

随机返回一个句子的下半

现在我们来看**random_sent()**函数

    def random_sent(self, index):"""随机决定是否交换下一句,用于训练BERT的下一句预测任务。参数:index (int): 数据项的索引。返回:tuple: 两个句子和一个标签,指示是否是下一句。"""t1, t2 = self.get_corpus_line(index)# output_text, label(isNotNext:0, isNext:1)if random.random() > 0.5:return t1, t2, 1else:return t1, self.get_random_line(), 0

现在我们来看**random_word()**函数

  • random_word()

85%的概率返回单词和单词在vocab字典中的映射,如果找不到返回unk的index

剩下的概率分别返回mask,随机值,不改变

    def random_word(self, sentence):"""对给定句子中的单词进行随机处理,用于生成BERT的输入和标签。参数:sentence (str): 输入的句子。返回:tuple: 处理后的单词列表和对应的标签列表。"""tokens = sentence.split()output_label = []for i, token in enumerate(tokens):prob = random.random()if prob < 0.15:prob /= 0.15# 80% randomly change token to mask tokenif prob < 0.8:tokens[i] = self.vocab.mask_index# 10% randomly change token to random tokenelif prob < 0.9:tokens[i] = random.randrange(len(self.vocab))# 10% randomly change token to current tokenelse:tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))else:# self.vocab.stoi:单词到index的映射# 查询不到返回unk的indextokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)output_label.append(0)return tokens, output_label

现在回头再来看getitem

  • getitem()
    def __getitem__(self, item):"""根据索引获取数据项,包括BERT输入序列、标签、段标签和下一句预测标签。参数:item (int): 数据项的索引。返回:dict: 包含BERT输入序列、标签、段标签和下一句预测标签的字典。"""t1, t2, is_next_label = self.random_sent(item)t1_random, t1_label = self.random_word(t1)t2_random, t2_label = self.random_word(t2)"""一个句子的头部加cls,尾部加eos标志label前面加pad,后面加pad"""# [CLS] tag = SOS tag, [SEP] tag = EOS tagt1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]t2 = t2_random + [self.vocab.eos_index]t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]t2_label = t2_label + [self.vocab.pad_index]# segment_label表示当前是第一句话还是第二句话,position的一部分segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[: self.seq_len]bert_input = (t1 + t2)[: self.seq_len]bert_label = (t1_label + t2_label)[: self.seq_len]padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))]bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)output = {"bert_input": bert_input,"bert_label": bert_label,"segment_label": segment_label,"is_next": is_next_label,}return {key: torch.tensor(value) for key, value in output.items()}

2. 训练代码

模型代码网上有很多,不再赘述

3. 推理代码

补充推理部分的代码

import torch
from bert_pytorch.model.bert import BERT
from bert_pytorch.model import BERTLM
from bert_pytorch.dataset import WordVocab
import numpy as npdef process(sentence: str) -> tuple:"""输入预处理,将句子转化为对应的id:param str sentence: 输入的句子:return tuple: [token,label]"""sentence = "hello world"tokens = sentence.split()output_label = []for i, token in enumerate(tokens):tokens[i] = vocab.stoi.get(token, vocab.unk_index)output_label.append(0)return tokens, output_label
def infer(s1:str, s2:str)->tuple:input_t1, _ = process("hello")input_t2, _ = process("World")t1 = [vocab.sos_index] + input_t1 + [vocab.eos_index]t2 = [vocab.sos_index] + input_t2 + [vocab.eos_index]# segment_label表示当前是第一句话还是第二句话,position的一部分segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:seq_len]bert_input = (t1 + t2)[:seq_len]bert_input=np.array(bert_input)bert_input=bert_input[None,:]bert_input=torch.from_numpy(bert_input)segment_label=np.array(segment_label)segment_label=torch.from_numpy(segment_label)next_sent_output, mask_lm_output = model(bert_input, segment_label)mask_lm_output = mask_lm_output.transpose(1, 2)# 待补充# 2020 256 8 8
# 2020是len(vocab),自己去看vocab.py里面的
seq_len = 20
f = torch.load("./output/bert.model.ep0")
bert = BERT(2020, 256, 8, 8)
model = BERTLM(bert, 2020)
model.load_state_dict(f)
# 加载词汇表,方便将预测的词转化为对应的id
vocab = WordVocab.load_vocab(r".\data\vocab.small")
infer("hello","world")

http://www.mrgr.cn/news/77846.html

相关文章:

  • 【数据分享】中国价格统计年鉴(2013-2024) PDF
  • uni-app打包H5自定义微信分享
  • 实验四:构建园区网(OSPF 动态路由)
  • ROS机器视觉入门:从基础到人脸识别与目标检测
  • 深度学习:GPT-1的MindSpore实践
  • 【单元测试】【Android】JUnit 4 和 JUnit 5 的差异记录
  • 鸿蒙进阶-状态管理之@Prop@Link
  • 【老白学 Java】Warship v2.0(三)
  • 增量预训练(Pretrain)样本拼接篇
  • Gate学习(6) 指令学习3
  • WPF异步UI交互功能的实现方法
  • cangjie (仓颉) vscode环境搭建
  • .NET9 - 新功能体验(二)
  • 使用bcc/memleak定位C/C++应用的内存泄露问题
  • #Verilog HDL# 谈谈代码中如何跨层次引用
  • 下载安装Android Studio
  • #Verilog HDL# Verilog中的ifdef/ifndef/else等用法
  • 每日一练:位运算-消失的两个数字
  • CNN—LeNet:从0开始神经网络学习,实战MNIST和CIFAR10~
  • 第三十四篇 MobileNetV1、V2、V3模型解析
  • 【计算机网络】数据链路层
  • 算法(Algorithm)
  • Playwright(Java版) - 7: Playwright 页面对象模型(POM)
  • 使用 Spring Boot 和 GraalVM 的原生镜像
  • win10局域网加密共享设置
  • 《计算力学学报》