Transformer组件(二):Attention变体
paperreading
本文字数:1.9k 字 | 阅读时长 ≈ 8 min

Transformer组件(二):Attention变体

paperreading
本文字数:1.9k 字 | 阅读时长 ≈ 8 min

在 Transformer 结构中,Attention 机制是核心,下面我们介绍各种 Attention 的变体

假设我们有一个句子:

“我 想 吃 酸 菜 鱼”
假设 每个字的表示是一个 12 维的向量(embedding 维度为 12),那么输入矩阵 $X$ 形状为:

$$
X = shape (6, 12) \quad {(6 个 token,每个 token 12 维)}
$$

其中:

1. Self-Attention

1.1 Self-Attention 计算

  1. 计算 Query, Key, Value
    $$
    Q = X W^Q, \quad K = X W^K, \quad V = X W^V
    $$

    • 其中 $ W^Q, W^K, W^V $ 是 (12, 12) 的可训练矩阵
    • 计算后 $ Q, K, V $ 形状:
      $$
      Q, K, V = {shape} (6, 12)
      $$
  2. 计算 Attention 分数
    $$
    {Attention}(Q, K, V) = {softmax} \left(\frac{QK^T}{\sqrt{d_k}}\right) V
    $$

    • $ QK^T $ 计算所有 token 之间的相似度,形状为 (6,6)
    • 经过 softmax 归一化后,最终得到 Attention 结果,形状仍然是 (6, 12)

1.2 问题

Self-Attention 只有一个 Attention 头,可能无法同时捕捉不同的依赖关系(如短距离 vs. 长距离依赖)。

2. Multi-Head Attention

Multi-Head Attention 的核心思想是:

2.1 切分 Q, K, V

假设我们用 3 个头(head)

2.2 每个头分别计算 Attention

$$
\begin{aligned}
{Attention}_1 & = {softmax} \left(\frac{Q_1 K_1^T}{\sqrt{4}}\right) V_1 \\
{Attention}_2 & = {softmax} \left(\frac{Q_2 K_2^T}{\sqrt{4}}\right) V_2 \\
{Attention}_3 & = {softmax} \left(\frac{Q_3 K_3^T}{\sqrt{4}}\right) V_3
\end{aligned}
$$

每个头的输出形状是 (6, 4)

2.3 拼接不同头的结果

$$
{MultiHead}(Q, K, V) = {Concat}({Attention}_1, {Attention}_2, {Attention}_3)
$$

$$
{shape} (6, 4) + (6, 4) + (6, 4) = (6, 12)
$$

2.4 线性变换整合多个头的信息

$$
{Output} = Concat({Attention}_{1}, {Attention}_2, {Attention}_3) W^O
$$

其中 $ W^O $ 是一个 $ (12, 12) $ 的权重矩阵,负责融合不同头的信息,最终得到 (6, 12) 的输出。

2.5 Self-Attention vs. Multi-Head Attention 总结

Self-Attention Multi-Head Attention
计算方式 计算 1 组 Q, K, V 并计算 Attention 计算多个 Q, K, V 并计算多个 Attention
维度 直接对完整的 12 维向量计算 把 12 维向量切成多个小块,分别计算
关注信息 只能关注一个特征 不同头关注不同特征,提高表达能力
表达能力 可能不足 更丰富,更能学习不同层次信息

下面用一个例子直观解释,假设你在调查一座城市:

3. Grouped Query Attention (GQA)

GQA 的核心思想是:减少 Query 头的数量,但 Key 和 Value 头仍然保持不变。

我们继续以上面的例子,假设句子长度为 6(“我 想 吃 酸 菜 鱼”),每个 token 的向量维度为 12 维(embedding size = 12),标准 MHA 使用 3 个 Attention 头,这里我们将 3 个头改为 6 个头,然后进行下面的实验

假设:

因此,输入矩阵形状:

$$
X = (6, 12) \quad \text{(6 个 token,每个 token 12 维)}
$$

3.1 Multi-Head Attention 计算(6 头)

在标准 Multi-Head Attention(MHA)中,将输入 X 变换为 Q, K, V
$$
Q = X W^Q, \quad K = X W^K, \quad V = X W^V
$$
假设 每个头负责 2 维(即 12 维 / 6 头 = 每个头 2 维),那么:

$$
\begin{aligned}
& Q_1, Q_2, Q_3, Q_4, Q_5, Q_6 \quad (\text{每个 Query 头 shape } = (6, 2)) \\
& K_1, K_2, K_3, K_4, K_5, K_6 \quad (\text{每个 Key 头 shape } = (6, 2)) \\
& V_1, V_2, V_3, V_4, V_5, V_6 \quad (\text{每个 Value 头 shape } = (6, 2))
\end{aligned}
$$

3.2 GQA 计算(3 个 Query 头,6 个 Key-Value 头)

GQA 的核心思想是:减少 Query 头的数量,但 Key 和 Value 头仍然保持 6 头。
1. Query 头减少为 3 头,但 Key 和 Value 仍然是 6 头:

$$
\begin{aligned}
& Q_1, Q_2, Q_3 \quad (\text{每个 Query 头 shape } = (6, 4)) \\
& K_1, K_2, K_3, K_4, K_5, K_6 \quad (\text{每个 Key 头 shape } = (6, 2)) \\
& V_1, V_2, V_3, V_4, V_5, V_6 \quad (\text{每个 Value 头 shape } = (6, 2))
\end{aligned}
$$

由于 Query 头减少,每个 Query 头的维度增加了一倍(6 头 → 3 头,12 维 / 3 = 每个头 4 维)。

Key 和 Value 仍然拆分为 6 头,每个头仍然是 2 维。

  1. Query 头共享 Key-Value 头进行计算:
    • Query 头 1(“我”、“想” 共享):
    $$
    Q_1 = \text{Concat}(X_1, X_2) W^Q_1 \quad (\text{shape } = (6, 4))
    $$
    • Query 头 2(“吃”、“酸” 共享):
    $$
    Q_2 = \text{Concat}(X_3, X_4) W^Q_2 \quad (\text{shape } = (6, 4))
    $$
    • Query 头 3(“菜”、“鱼” 共享):
    $$
    Q_3 = \text{Concat}(X_5, X_6) W^Q_3 \quad (\text{shape } = (6, 4))
    $$

  2. 计算 Attention(共享 Key-Value):
    • 由于 Query 头减少,多个 Query 共享 Key 和 Value:

$$
\begin{aligned}
\text{Attention}1 &= \text{softmax} \left(\frac{Q_1 K{(1,2)}^T}{\sqrt{d_k}}\right) V_{(1,2)} \\
\text{Attention}2 &= \text{softmax} \left(\frac{Q_2 K{(3,4)}^T}{\sqrt{d_k}}\right) V_{(3,4)} \\
\text{Attention}3 &= \text{softmax} \left(\frac{Q_3 K{(5,6)}^T}{\sqrt{d_k}}\right) V_{(5,6)}
\end{aligned}
$$

  1. 最终输出拼接回 12 维:
    $$
    \text{Concat}(\text{Attention}_1, \text{Attention}_2, \text{Attention}_3)
    $$
    $$
    \text{shape} (6, 12)
    $$

3.3 GQA vs MHA 的计算对比

Multi-Head Attention (MHA) Grouped Query Attention (GQA)
Query 头数 6 个
Key 头数 6 个
Value 头数 6 个
每个 Query 头维度 2 维
计算复杂度 O(6 \times d_q d_k)
计算复杂度 O(3 \times d_q d_k)(更快)

MHA:每个 Query 头独立计算 Attention,计算量大。
GQA:减少 Query 头,但 Key 和 Value 头不变,从而减少计算量,提高推理速度,同时保持较强的表达能力。

3.4 直观比喻

Multi-Head Attention(MHA):你找了 6 位助教,每位助教独立评分所有学生。

Grouped Query Attention(GQA):你减少到 3 位助教,但让每位助教参考之前 6 位助教的评分标准(Key 和 Value 仍然 6 组)。