BERT 训练代码
paperreading
本文字数:1.5k 字 | 阅读时长 ≈ 6 min

BERT 训练代码

paperreading
本文字数:1.5k 字 | 阅读时长 ≈ 6 min

本文代码和实际情况有所出入,写本文主要是通过文章来帮助刚入门的同学快速理解 BERT 原文中的思路,文章分为三部分:数据集的构建和选取,网络结构和 loss 计算

1. 数据集的构建和提取

1.1 数据集的格式

在 BERT 原文中采用了 Bookcorpus 和 Wikipedia 数据集,并且是 document-level 的,也就是说他们的 sequence 选取是在文档中截取的连续 token,类似于如下形式

i wish i had a better answer to that question .

然后采用 WordPiece 的方法对其进行截取,截取完之后得到如下 token,一共 9 个 token,并用数字对其编号

i wish had a better answer to that question
4, 5, 6, 7, 8, 9, 10, 11, 12

至于为什么从 4 开始,这是因为原文中有「CLS」,「MASK」,「Seq」等 token,我们对其编码 0,1,2 等。以此类推,对于整个数据集,语料库就是这样构建的,按照原文,一共 30000 个 token,即编号从 0 一直到 30000

1.2 数据集的提取

构建完数据集之后,我们需要对其进行提取,需要编写 dataset.py 文件,这里展示 getitem 函数的一些核心内容

  1. 首先从文档中提取两句话(这里为了方便,我只提取一句话,但是将这句话分为两段),返回 t1,t2 和 is_next_label,其中 is_next_label 表示这两个 sentence 是不是连续的,代码如下
t1, t2, is_next_label = self.random_sent(item)

'''
t1: i have been taken over
t2: car for a manuscript , then you tell me .\n
is_next_label: 0
'''
  1. 随后提取 t1 和 t2 的 token 编号(就是之前介绍的数字编号)以及进行 mask 操作, 看最后的返回,t1_random 就是返回的编号,比如在这里 i 这个单词对应的索引编号为 12,have 的编号为 37,t1_label 都为 0 表示这个 t1_random 没有任何单词被 mask 过,我们再来看 t2_random,其中有一个为 4,代表被 mask 了(因为事先已经将 mask 编号为了 4),这里 tell 和".\n"被 mask 了,t2_label 里面有两个非 0 元素,182 和 5,表示这两个被 mask 的 token 被 mask 实际编号为 182 和 5,我们之后 MLM 预测时就要预测这两个分类值。
# mask以及mask位置原来的label
t1_random, t1_label = self.random_word(t1)
t2_random, t2_label = self.random_word(t2)

# random_word的核心代码
for i, token in enumerate(tokens):
    prob = random.random()
    if prob < 0.15:  # 15%的概率采取mask措施
        prob /= 0.15

        if prob < 0.8:
            tokens[i] = self.vocab.mask_index  # 80% change token to mask token
        elif prob < 0.9:
            tokens[i] = random.randrange(len(self.vocab))  # 10% change token to random token
        else:
            tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) # 10% change token to current token
        output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
    else:
        tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
        output_label.append(0)

'''
t1_random: [12, 37, 66, 435, 73]
t1_label: [0, 0, 0, 0, 0]
t2_random: [430, 26, 10, 6285, 7, 63, 19, 4, 39, 4]
t2_label: [0, 0, 0, 0, 0, 0, 0, 182, 0, 5]
'''
  1. 这一步是额外加的,因为我这里的 sentence 太短了,我的 toy example 只设置了 max sentence 为 20,所以会出现没有 mask 的情况,也就是上面的 t1_label,这样计算 MLM 的时候会出现 Nan 问题(但是在训练大模型中不会出现这个问题,因为大模型的 token 很长,有 512 或者 1024 等,也就是说有 500+的单词,而且 mask 的概率也是提前设置的,所以会 mask 固定的数量,不会存在没被 mask 的情况)。所以我们接下来 check 以下 t1_label 和 t2_label 是否有没有被 mask 的,如果没有,将没有被 mask 的 sentence 至少 mask 一个。这里只有 t1_label 存在非 mask 的元素,所以 check 后只有 t1 变化了
 # 因为这里的token很短,会出现没有token被mask的情况
# 这样计算mlm会出现nan,因此检查一下,让label中至少一个被mask
t1_random, t1_label = self.check(t1_random, t1_label)
t2_random, t2_label = self.check(t2_random, t2_label)

'''
t1_random: [12, 4, 66, 435, 73]
t1_label: [0, 1, 0, 0, 0]
t2_random: [430, 26, 10, 6285, 7, 63, 19, 4, 39, 4]
t2_label: [0, 0, 0, 0, 0, 0, 0, 182, 0, 5]
'''
  1. 这里我们加上「CLS」和「SEP」这两个 token,他们在我们的设置中编号分别为 3 和 2,然后将 t1_label 和 t2_label 也进行相应更改,因为不是 mask 区域,所以加入 0 即可
# [CLS] tag = SOS tag, [SEP] tag = EOS tag
t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
t2 = t2_random + [self.vocab.eos_index]

t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
t2_label = t2_label + [self.vocab.pad_index]
'''
t1_random: [3, 12, 4, 66, 435, 73, 2]
t1_label: [0, 0, 1, 0, 0, 0, 0]
t2_random: [430, 26, 10, 6285, 7, 63, 19, 4, 39, 4, 2]
t2_label: [0, 0, 0, 0, 0, 0, 0, 182, 0, 5, 0]
'''
  1. 最后定义一下 segment_label,即句子的标号,由于我们设置了最大 sentence 的长度,截取一下即可,注意如果整个 sentence 达不到预先定义的长度,比如本例一共 18,我预先设置的 max 为 20,就会进行 padding,将 18 填充到 20,padding 的编号默认为 0,注意箭头后面是我特意标定的补充值
segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
bert_input = (t1 + t2)[:self.seq_len]
bert_label = (t1_label + t2_label)[:self.seq_len]
padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))]
bert_input.extend(padding)  # token padding
bert_label.extend(padding)  # mask label padding
segment_label.extend(padding)  # segment label padding

'''
bert_input: [3, 12, 4, 66, 435, 73, 2, 430, 26, 10, 6285, 7, 63, 19, 4, 39, 4, 2, 0, -> 0, 0]
bert_label: [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 182, 0, 5, 0, -> 0, 0]
segment_label: [1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, -> 0, 0]
'''

至此数据集的预处理和提取就好了,可以输入到网络进行训练

2. 网络结构

网络结构分为两种

3. Loss 计算

网络输出有两个,分别是 MLM 的输出和 NSP 的输出,其中二者均是分类问题

其中 MLM 输出的是 mask 部分的结果,由于非 mask 部分的编号为 0,所以 loss=nn.NLLLoss(ignore_index=0)要忽略掉 0 类,如果在网络输出没用 logsoftmax,这里要用 crossentropy 损失

NSP 输出的是 0 或者 1,loss=nn.NLLLoss()就不需要忽略掉 0 类了