GPT-2 训练和测试代码解析
paperreading
本文字数:4.8k 字 | 阅读时长 ≈ 23 min

GPT-2 训练和测试代码解析

paperreading
本文字数:4.8k 字 | 阅读时长 ≈ 23 min

GPT(Generative Pre-training)模型首次由 OpenAI 在 2018 年提出:Improving Language Understanding by Generative Pre-Training,是 Decoder-only 架构,这只包含 Transformer 的 Decoder 部分,由于当前时间步的信息输出只能依赖于之前的信息,不像 BERT 一样可以看到之后的信息,因此称为因果模型,即 casual transformer,但是叫自回归模型(auto-regressive transformer)或者 decoder-only 模型较多

“因果”也源于模型处理信息的方式,即生成文本的方式。GPT 在生成当前文本时,考虑之前所有的文本,这个过程遵循一个因果链,即一个概率模型

本文主要以 GPT-2 来介绍 Decoder-only 模型的训练和推理过程,最后介绍 LoRA 的原理,即一种降低显存来优化语言模型的方法

1. GPT-2 的训练过程

本文介绍的模型主要采用 GPT-2 Medium,$Layers=24, d_{model}=1024$,并且我们参考LoRA的代码进行学习,采用的数据集市 E2E NLG Challenge,根据键值对生成餐馆描述

1.1 数据预处理

1.1.1. 数据处理的基本思路

处理语言的第一步是将一句话基于某种规则进行分词,然后 token 化,例如下面例子

# Input
Using a Transformer network is simple
# Token
['Using', 'a', 'transform', '##er', 'network', 'is', 'simple']
# Id
[7993, 170, 11303, 1200, 2443, 1110, 3014]

其中分词之后的每一个 Token 都有自己的 id,最后 Transformer 的输出就变成了一个分类问题,即预测下一个 Token。上面的句子 Token 化是通过某种方法学来的,这里不做详细赘述。

1.1.2. 本文使用的数据集

知道数据处理的基本思路之后,我们看一下本文的数据格式,

name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Café Brazil||Close to Café Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of £ 10.50 . Delicious Pub food .

其中 name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Café Brazil 是键值对,Close to Café Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of £ 10.50 . Delicious Pub food .是生成的描述,我们最终的目的就是通过键值对来生成其描述,数据最终的预处理方法将词 Token 化,然后取其 id

那么 id 如何转化为向量呢,我们有一个词表,包含了所有的 Token,最终输入到网络中时就是根据 Token 的 Id 在词表中进行索引(在 1.2 部分会详细解释)

1.1.3. 输入之前数据的处理

假设我们从数据集中随机抽取一个样本,并且 Token 化,现在我们假设其问题和答案分别如下所示

# question
[3672, 1058, 4285, 4763, 930, 5994, 1058, 2240, 930, 2057, 1058, 3594, 930, 6491, 7955, 1058, 352, 503, 286, 642, 930, 1474, 1058, 575, 3974, 1453, 399, 27106, 2409, 50256]

# answer
[7911, 3187, 4285, 4763, 837, 674, 649, 2240, 5140, 1474, 575, 3974, 1453, 399, 27106, 2409, 764, 49208, 3594, 2057, 290, 40776, 257, 6491, 7955, 286, 352, 503, 286, 642, 837, 836, 705, 83, 1208, 428, 1295, 510, 764, 50256]

那么 input 为二者的拼接,并且我们会设置一个最大序列长度(max sequence length),如果 question+answer 的长度没有达到最大序列长度,我们会将剩余部分补 0,这是因为同时处理多个数据时,他们的长度是不同的,所以我们要将他们补齐到同样的长度。最终的 GT 为 input 右移一位,这是因为当前词的输出只能够看到他前面的词汇,所以第一个位置的输出只能看到 input 的第一个词汇,第二个位置的输出只能看到 input 的第一个和第二个词汇,以此类推,下面是最终的 input,target 和 mask,其中 mask 将 target 对应的 answer 部分设为 1

# input
[3672, 1058, 4285, 4763, 930, 5994, 1058, 2240, 930, 2057, 1058, 3594, 930, 6491, 7955, 1058, 352, 503, 286, 642, 930, 1474, 1058, 575, 3974, 1453, 399, 27106, 2409, 50256, 7911, 3187, 4285, 4763, 837, 674, 649, 2240, 5140, 1474, 575, 3974, 1453, 399, 27106, 2409, 764, 49208, 3594, 2057, 290, 40776, 257, 6491, 7955, 286, 352, 503, 286, 642, 837, 836, 705, 83, 1208, 428, 1295, 510, 764, 50256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# target
[1058, 4285, 4763, 930, 5994, 1058, 2240, 930, 2057, 1058, 3594, 930, 6491, 7955, 1058, 352, 503, 286, 642, 930, 1474, 1058, 575, 3974, 1453, 399, 27106, 2409, 50256, 7911, 3187, 4285, 4763, 837, 674, 649, 2240, 5140, 1474, 575, 3974, 1453, 399, 27106, 2409, 764, 49208, 3594, 2057, 290, 40776, 257, 6491, 7955, 286, 352, 503, 286, 642, 837, 836, 705, 83, 1208, 428, 1295, 510, 764, 50256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# mask
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

1.2 Transformer 的输入

1.2.1. token 的向量化

在数据处理部分,每个句子被分为一个个 token,每个 token 在词表中均有一个对应的 id,我们通过这个 id 可以索引他在词表中的向量

在 padding 之后,每个 batch 里面的句子均被填充为了 512 个 Token,不到 512 Token 的用 0 进行补全,假设 batchsize=2,输入维度则为 [2, 512],此外每个词都有对应的位置编码,让模型知道每个 Token 在句子中的位置,最终从 Token 到向量的过程如下所示

其中 GPT-2 的词表长度为 50257 个 Token,由于 max_sequence_length 为 512,所以位置编码一共 512 个 Token,在 Token 向量化之后输入到 Transformer 的词向量维度为 [2, 512, 1024]

1.3 Transformer 中进行 Attention

embedding 之后的向量维度为 [2, 512, 1024],接下来经过 24 个 Transformer block,我们首先看每个 block 的整体结构,如下图所示,每个 block 分为 attention 部分和 feed forward network 部分,下面我们分别对 Masked Multi-head Attention 和 MLP 部分分别做介绍

1.3.1. Masked Multi-head Attention

我们直接通过一张图展示 Masked Multi-head Attention 这一模块的向量变化

注意上图中,在 softmax 之后加了一个上三角矩阵,这是因为 GPT-2 这种自回归模型是因果模型,当前词只能看到前面的词,看不到后面的词。

self-attention 回顾

下面我们通过一个例子简单回顾一下 self-attention,如下图所示,自注意力机制能够看到所有的词,例如 The 这个词,attention 的结果就是计算让其余所有的词对它的权重其进行加权,然而 Mased attention 只能看到前面的词,同样以这个例子为例,The 只能用 [The] 进行加权,weather 可以用 [The, weather] 进行加权,is 可以用 [The, weather, is] 进行加权,以此类推

Masked Multi-head self-attention

Masked Multi-head self-attention 和 self-attention 的不同就在 attention 矩阵上,在 self-attention 中,我们得到 attention 矩阵之后直接用 softmax 归一化即可,但是在 Masked 中,由于是因果推理,我们会构造一个上三角矩阵,将右上角的元素变为无穷,如图所示,灰色部分表示负无穷,这样 softmax 的时候灰色部分就不会影响计算,从而 the 只能关注到 the,weather 会由 The 和 weather 关注,达到当前元素只能看到他之前元素的效果,然后得到的 attention 矩阵与 v 相乘得到 attention 之后的向量

1.3.2. Feed Forward Network

FFN 比较简单,包含一个 LN 和 MLP,MLP 如下所示,就是将 word embedding 的维度升为原来的 4 倍到 4096,然后降维回 1024 即可

1.4 Transformer 输出

经过上面的 Transformer block,最终的向量输出维度依然等于 hidden state 的维度 [2, 512, 1048],他会经过一个 head 进行输出,head 就是一个线性层,输出为词表的大小 linear(1024, 50257),最终与 GT 计算交叉熵损失

2. GPT-2 的推理过程

2.1 推理输入

在训练过程中,由于 Masked Attention 的存在,训练可以一次训练多句话,例如"The weather is nice day",假设一共分为五个 token,那么 The 对应的输出就应该是 weather,weather 对应的输出就是 is,is 对应的输出就是 nice,以此类推,我们同时可以对所有的词进行预测,即预测一次就相当于把"The",“The weather”,“The weather is”…预测完了,但是在 inference 的时候,我们需要一个词一个词输出预测了,预测完当前词再过一遍模型才能预测下一个词

下面用 GPT-2 的 beam search 方法来举例(源代码用的这个,其实理解了 beam 方法就能弄懂推理过程了)

输入数据,在测试时,输入数据就没有后面的描述了(【】部分),我们只能根据键值对来生成描述

name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Café Brazil|| 【Close to Café Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of £ 10.50 . Delicious Pub food .】

[3672, 1058, 383, 9212, 930, 5994, 1058, 2240, 930, 2057, 1058, 12549, 2057, 930, 2756, 1058, 1029, 930, 6491, 7955, 1058, 352, 503, 286, 642, 930, 1989, 1058, 18180, 485, 930, 1641, 8030, 1058, 3763, 930, 1474, 1058, 42151, 28799, 17517, 50256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

2.2 第一次推理

这里设置的 beam=10,那么输入就从 [1, 512] 变为 [10, 512],根据词表可以得到输入向量为 [10, 512, 1024],加上位置编码的 [1, 512, 1024] 就可以得到最终的输入了 [10, 512, 1024],这里我们假设问题一共有 42 个 token,也就是说从 43 到 512 个 token 均为 0

首先经过第 transformer block,在 attention 的过程中,能够得到 qkv,这里我们将 k 和 v 都存下来(因为在第二次推理的时候能用到),kv 分别为 [10, 512, 1024],维度变换一下为 [10, 16, 512, 64],其中 16 为 head 的数量,我们将 kv 两个向量 concat 起来得到 [2, 10, 16, 512, 64],注意每一层 block 都要存,这样最后就能保存 24 个 kv 向量,分别对应 transformer block 的 0-23 层

随后 mask 矩阵采用和训练时用的上三角矩阵即可,其余的没有区别正常推理,最终得到向量维度为 [10, 512, 50257],我们将最后一维度通过 softmax,就可以得到每一个 token 分类的概率分布,这里因为是第一次推理,所以 10 个 batch 都是相同的,所以我们取第一个 batch 的结果即可,然后我们获取 top 10,也就是最大的十个概率的对应 Token 编号作为答案的第一个 token,即第 43 个 token

2.3 第二次推理

第二次推理时,我们已经保存了第一次的 24 个 kv,每个 kv 维度为 [2, 10, 16, 512, 64],第二次输入的时候只输入上次预测的 Token,即第 43 个 token,维度为 [10, 1],我们根据词表和位置编码的第 43 个向量将其向量化为 [10, 1, 1024],然后经过 attention 时要变为 [10, 16, 1, 64],那么他是如何进行 attention 的呢?这就用到第一次推理得到的 kv 了,当我们经过第一个 transformer block 时,我们拿出第一次推理时的第一个 block 存的 kv,即 [2, 10, 16, 512, 64],由于当前 token 是第 43 的 token,因此我们将第 43 个 token 进行替换,替换为我们的当前 token 向量 [10, 16, 1, 64],此时新的 kv 会作为新的 cache 存起来为下次推理做准备,新的 kv 也会进行 attention,此时 q 即为输入 [10, 16, 1, 64],kv 分别为 [10, 16, 64, 512],最终得到的 attention 矩阵为 [10, 16, 1, 512],mask 该怎么得到呢?此时从原来的 [512, 512] 变为了 [1, 512],因为当前要预测 44 个 Token 了,所以我们将 43 个 Token 之后的(不包含 43)attention 矩阵中的值变为-inf,即 [1, 512] 中 43 之后的变为-inf,然后正常输出即可,和第一次推理一样

2.4 第 N 次推理

第 N 次推理的过程和第二次完全一样,直到遇到 <eos>token 或者到模型的最大长度为止

4月 06, 2025
3月 10, 2025
12月 31, 2024