BERT解析
BERT项目
我在BERT添加注释和部分推理代码
main.py
vocab = WordVocab.load_vocab(args.vocab_path)#加载vocab
那么这个加载的二进制是什么呢?
1. 加载数据集
继承关系:TorchVocab --> Vocab --> WordVocab
- TorchVocab
该类主要是定义了一个词典对象,包含如下三个属性:
freqs
:是一个collections.Counter
对象,能够记录数据集中不同token所出现的次数
stoi
:是一个collections.defaultdict
对象,将数据集中的token映射到不同的index
itos
:是一个列表,保存了从index到token的映射信息
- Vocab
Vocab继承TorchVocab,该类主要定义了一些特殊token的表示
这里用到了一个装饰器,简单地说:装饰器就是修改其他函数的功能的函数。这里包含了一个序列化的操作
- WordVocab
WordVocab继承自Vocab,里面包含了两个方法to_seq
和from_seq
分别是将token转换成index和将index转换成token表示
2.datasets.py
BERTDataset这个类有——个方法
- init():初始化BERTDataset类。
- len():返回数据集的大小,即语料库的行数。
- getitem():根据索引获取数据项,包括BERT输入序列、标签、段标签和下一句预测标签。
- random_word():对给定句子中的单词进行随机处理,用于生成BERT的输入和标签。
- random_sent():随机决定是否交换下一句,用于训练BERT的下一句预测任务。
- get_corpus_line():根据索引获取语料库中的句子对。
- get_random_line():获取语料库中的一个随机句子。
- init()
def __init__(self,corpus_path,vocab,seq_len,encoding="utf-8",corpus_lines=None,on_memory=False,):"""初始化BERTDataset类。 如果on_memory为True,则将语料库加载到内存中。否则,计算语料库的行数。参数:corpus_path (str): 语料库文件的路径。vocab (Vocab): 词汇表对象,用于将单词映射到索引。seq_len (int): 序列长度,BERT输入序列的最大长度。encoding (str, optional): 文件编码,默认为'utf-8'。corpus_lines (int, optional): 语料库行数,如果提供,则在初始化时不会计算。on_memory (bool, optional): 是否将整个语料库加载到内存中,默认为True。"""self.vocab = vocabself.seq_len = seq_lenself.on_memory = on_memoryself.corpus_lines = corpus_linesself.corpus_path = corpus_pathself.encoding = encodingwith open(corpus_path, "r", encoding=encoding) as f:if self.corpus_lines is None and not on_memory:for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):self.corpus_lines += 1if on_memory:self.lines = [line[:-1].split("\t")for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]print('第一行的内容',self.lines[0])print("第一行的数量:",len(self.lines[0]))self.corpus_lines = len(self.lines)"""第一行的内容 ['Cuba to Get Rid of Dollars After a Decade', " HAVANA (Reuters) Cubans rushed to change dollars into local pesos on Tuesday as President Fidel Castro's communist government prepared to pull the U.S. currency from circulation more than a decade after it was legalized here."]第一行的数量: 2"""if not on_memory:self.file = open(corpus_path, "r", encoding=encoding)self.random_file = open(corpus_path, "r", encoding=encoding)for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):self.random_file.__next__()
corpus.small中前三行的内容:
所以__init__方法就是将句子分为上下两句,存储在列表中
- len()
def __len__(self):"""返回数据集的大小,即语料库的行数。返回:int: 语料库的行数。"""return self.corpus_lines
介绍getitem()之前,我们要先看**random_sent()和random_word()**这两个函数
def __getitem__(self, item):"""根据索引获取数据项,包括BERT输入序列、标签、段标签和下一句预测标签。参数:item (int): 数据项的索引。返回:dict: 包含BERT输入序列、标签、段标签和下一句预测标签的字典。"""t1, t2, is_next_label = self.random_sent(item)t1_random, t1_label = self.random_word(t1)t2_random, t2_label = self.random_word(t2)
- random_sent()
在这个函数里面我们先看get_corpus_line()函数和get_random_line()
- get_corpus_line()
t1和t2是表示上半个句子和下半个句子,用于训练BERT的下一句预测任务。
def get_corpus_line(self, item):"""根据索引获取语料库中的句子对。参数:item (int): 数据项的索引。返回:tuple: 一个句子对。"""if self.on_memory:#返回一句话中的上半和下半return self.lines[item][0], self.lines[item][1]else:line = self.file.__next__()if line is None:self.file.close()self.file = open(self.corpus_path, "r", encoding=self.encoding)line = self.file.__next__()t1, t2 = line[:-1].split("\t")return t1, t2
- get_random_line()
随机返回一个句子的下半
现在我们来看**random_sent()**函数
def random_sent(self, index):"""随机决定是否交换下一句,用于训练BERT的下一句预测任务。参数:index (int): 数据项的索引。返回:tuple: 两个句子和一个标签,指示是否是下一句。"""t1, t2 = self.get_corpus_line(index)# output_text, label(isNotNext:0, isNext:1)if random.random() > 0.5:return t1, t2, 1else:return t1, self.get_random_line(), 0
现在我们来看**random_word()**函数
- random_word()
85%的概率返回单词和单词在vocab字典中的映射,如果找不到返回unk的index
剩下的概率分别返回mask,随机值,不改变
def random_word(self, sentence):"""对给定句子中的单词进行随机处理,用于生成BERT的输入和标签。参数:sentence (str): 输入的句子。返回:tuple: 处理后的单词列表和对应的标签列表。"""tokens = sentence.split()output_label = []for i, token in enumerate(tokens):prob = random.random()if prob < 0.15:prob /= 0.15# 80% randomly change token to mask tokenif prob < 0.8:tokens[i] = self.vocab.mask_index# 10% randomly change token to random tokenelif prob < 0.9:tokens[i] = random.randrange(len(self.vocab))# 10% randomly change token to current tokenelse:tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))else:# self.vocab.stoi:单词到index的映射# 查询不到返回unk的indextokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)output_label.append(0)return tokens, output_label
现在回头再来看getitem
- getitem()
def __getitem__(self, item):"""根据索引获取数据项,包括BERT输入序列、标签、段标签和下一句预测标签。参数:item (int): 数据项的索引。返回:dict: 包含BERT输入序列、标签、段标签和下一句预测标签的字典。"""t1, t2, is_next_label = self.random_sent(item)t1_random, t1_label = self.random_word(t1)t2_random, t2_label = self.random_word(t2)"""一个句子的头部加cls,尾部加eos标志label前面加pad,后面加pad"""# [CLS] tag = SOS tag, [SEP] tag = EOS tagt1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]t2 = t2_random + [self.vocab.eos_index]t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]t2_label = t2_label + [self.vocab.pad_index]# segment_label表示当前是第一句话还是第二句话,position的一部分segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[: self.seq_len]bert_input = (t1 + t2)[: self.seq_len]bert_label = (t1_label + t2_label)[: self.seq_len]padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))]bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)output = {"bert_input": bert_input,"bert_label": bert_label,"segment_label": segment_label,"is_next": is_next_label,}return {key: torch.tensor(value) for key, value in output.items()}
2. 训练代码
模型代码网上有很多,不再赘述
3. 推理代码
补充推理部分的代码
import torch
from bert_pytorch.model.bert import BERT
from bert_pytorch.model import BERTLM
from bert_pytorch.dataset import WordVocab
import numpy as npdef process(sentence: str) -> tuple:"""输入预处理,将句子转化为对应的id:param str sentence: 输入的句子:return tuple: [token,label]"""sentence = "hello world"tokens = sentence.split()output_label = []for i, token in enumerate(tokens):tokens[i] = vocab.stoi.get(token, vocab.unk_index)output_label.append(0)return tokens, output_label
def infer(s1:str, s2:str)->tuple:input_t1, _ = process("hello")input_t2, _ = process("World")t1 = [vocab.sos_index] + input_t1 + [vocab.eos_index]t2 = [vocab.sos_index] + input_t2 + [vocab.eos_index]# segment_label表示当前是第一句话还是第二句话,position的一部分segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:seq_len]bert_input = (t1 + t2)[:seq_len]bert_input=np.array(bert_input)bert_input=bert_input[None,:]bert_input=torch.from_numpy(bert_input)segment_label=np.array(segment_label)segment_label=torch.from_numpy(segment_label)next_sent_output, mask_lm_output = model(bert_input, segment_label)mask_lm_output = mask_lm_output.transpose(1, 2)# 待补充# 2020 256 8 8
# 2020是len(vocab),自己去看vocab.py里面的
seq_len = 20
f = torch.load("./output/bert.model.ep0")
bert = BERT(2020, 256, 8, 8)
model = BERTLM(bert, 2020)
model.load_state_dict(f)
# 加载词汇表,方便将预测的词转化为对应的id
vocab = WordVocab.load_vocab(r".\data\vocab.small")
infer("hello","world")