LoRA 代码解析
paperreading
本文字数:623 字 | 阅读时长 ≈ 2 min

LoRA 代码解析

paperreading
本文字数:623 字 | 阅读时长 ≈ 2 min

1. LoRA 论文解读

这里先不解读了,直接看 Microsoft 的代码

2. LoRA 代码解析

lora 代码最初由微软GitHub提供,下面就对这套代码来学习 lora。

这里以 NLG 任务中的 e2e 数据集为例

2.1 对数据的预处理

如下图所示,e2e 任务是给定一些独立词汇组成一个句子,input 中的 x 为独立词,y 为组成的句子,0 为填充,将所有的 token 填充到 512 个。对于 Target 由 next token prediction 策略,只需要将所有 token 向右移动一位即可,这样是 511 个 token,然后在后面补 0 即可。Mask 是一个标记,仅标记 y 所在位置。

2.2 Tokenization 处理

再输入到 transformer 中,我们首先对其进行向量化处理,如图所示,我们从词表中 wte(50257, 1024 找到每个 token 对应的词向量,同时从 wpe(1024, 1024) 中找到每个 token 对应的位置向量,然后将两个向量相加,得到最终的输入向量

2.3 Transformer 与 LoRA

接下来是代码中最重要的部分,在 GPT-2 中,这里的模型有 24 个 transformer block,head=16,hidden=1024,其中这里绘制了两个图,其中上面的图是正常的 transformer 中的 attention 操作:LN-Attn,由于上面的部分非常熟悉了,这里就不做介绍了,只对下面的加入 lora 的部分进行介绍。

其中 GPT-2 中的 lora 只作用在 attention 里生成 qkv 的部分,当输入通过 LN 之后,之前的方法是通过一个 Linear 层,但是微调这一层代价太大(如果只微调一个 linear 层还好,但是一共有 24 个 layer,就需要微调 24 个 linear,代价太大),因此我们冻住 Linear 层,在上面加入一个 LoRA 残差网络,如图所示 LoRA 中有一些超参数,例如 r 是秩,这里设为了 4。此外我们只对 QV 进行 lora 更新,并不更新 k,这是因为实验发现微调 QV 的效果更好,因此经过 LoRA 微调后,我们只需要更新 Lora_ALora_B 这两个参数即可,大大降低了优化的难度

经过 attention 之后,我们通过一个 ffn 层即可,这个很简单不做赘述

2.4 GPT-2Head 输出

最终我们通过一个 linear 层进行输出,由于我们的嵌入词表的大小为 50257,因此我们的 linear 层输出也是 50257。

4月 06, 2025
3月 10, 2025
12月 31, 2024