PaperReading:FlashAttention
paperreading
本文字数:2.1k 字 | 阅读时长 ≈ 8 min

PaperReading:FlashAttention

paperreading
本文字数:2.1k 字 | 阅读时长 ≈ 8 min

1. 背景

在现代计算机架构中,内存层次结构是影响系统性能的核心因素之一。从PC到服务器,再到用于AI和高性能计算的GPU加速器,不同层次的内存(如SRAM、HBM、DRAM、Cache)扮演着不同的角色。在介绍 FlashAttention 之前,我们先来了解一下 CPU、GPU、SRAM、HBM、DRAM、L1/L2/L3 Cache 等概念,并解释它们之间的区别,因为 FlashAttention 的优化正是基于这些概念。

1. CPU 与 GPU:计算核心的分工

2. 内存层次结构:SRAM、HBM、DRAM 和 Cache

内存越靠近核心速度越快、容量越小

3. SRAM 和 HBM 的角色

特性 SRAM(片上缓存) HBM(高带宽内存)
位置 核心内部(紧邻计算单元) 核心外部(与计算单元连接)
容量 小(KB~MB) 大(几十GB)
速度 非常快(低延迟) 快但比SRAM慢
用途 缓存计算所需数据块 存储大量输入、中间结果
类比 计算工作台(临时用具) 仓库(大批物品)

2. FlashAttention v1

2.1 问题

以前(没有FlashAttention时)大多数深度学习框架和Transformer实现把计算交给HBM,但实际上,SRAM在GPU内部始终是存在的,但是未利用好。

FlashAttention主动设计了“把分块数据放进SRAM计算”,让SRAM不再只是“被动缓存”,而是主动参与调度和计算。

论文中证明了 FlashAttention 在访问高带宽内存(HBM)时的复杂度是:
$$
O \left( \frac{N^2 d^2}{M} \right)
$$

这里:

这意味着,序列 N 越长,IO 访问次数越多(但比传统方法少很多);片上缓存(SRAM)越大 M,每次能加载更多数据块,访问次数就越少。相比之下,传统标准Attention的HBM访问复杂度是:
$$
\Omega(N d + N^2)
$$
也就是说,标准Attention需要:线性访问Q/K/V($Nd$),存取完整 $N \times N$ 的注意力矩阵($N^2$)。

$O$ 和 $\Omega$ 的含义

1️⃣ $O\bigl(\frac{N^2 d^2}{M}\bigr)$,表示上界,用于描述算法在最坏情况下的增长速度,意为“最多是这个量级”:
- 这里表示 FlashAttention的HBM访问复杂度的渐进上界。
- 意味着当序列长度 $N$、head维度 $d$、SRAM大小 $M$ 越来越大时,HBM的访问次数最多是 $\frac{N^2 d^2}{M}$ 级别。

2️⃣ $\Omega(Nd + N^2)$,表示下界,用于描述算法在最好的情况下的增长速度,意为“至少是这个量级”:
- 这里描述 标准Attention的HBM访问复杂度下界。
- 意味着无论你怎么优化,HBM访问次数至少是 $Nd + N^2$ 级别。

2.2 标准注意力实现

给定输入序列 $Q, K, V \in \mathbb{R}^{N \times d}$,其中 $N$ 是序列长度,$d$ 是head维度,我们希望计算注意力输出 $O \in \mathbb{R}^{N \times d}$:
$$
S = QK^\top \in \mathbb{R}^{N \times N} ~~~
P = \mathrm{softmax}(S) \in \mathbb{R}^{N \times N} ~~~
O = PV \in \mathbb{R}^{N \times d}
$$
标准的注意力实现会将中间矩阵 $S$ 和 $P$ 存储在 HBM(高带宽内存)中,这会占用 ${O}(N^2)$ 的内存。通常 $N \gg d$(例如在 GPT-2 中,$N=1024$, $d=64$)。

由于大部分操作是内存受限(例如 softmax),大量的内存访问会导致运行时间变慢。这个问题在对注意力矩阵应用其他逐元素操作(比如对 $S$ 进行掩码或对 $P$ 进行dropout)时更加严重。

2.3 解决方案

在Transformer的注意力中,Softmax操作对中间矩阵 $S=QK^\top$ 进行归一化。但直接计算完整 $S$ 会占用大量内存。FlashAttention采用 分块(tiling)重计算(recomputation) 技术,避免存储整个矩阵

1. 软最大值(Softmax)计算公式
对于向量 $x \in \mathbb{R}^B$:
$$
m = \max_i x_i ~~~~~ f_i = e^{x_i - m} ~~~~~ l = \sum_i f_i ~~~~~ \mathrm{softmax}(x) = \frac{f}{l}
$$

为了防止数值溢出,先减去最大值 $m$,再进行指数和归一化。

2. 块分解计算

如果将向量分为两块 $x^{(1)}, x^{(2)}$,拼接后 $x = [x^{(1)}, x^{(2)}]$:

合并后的Softmax:
$$
\mathrm{softmax}(x) = {\bigl[e^{x{(1)} - m}, e^{x{(2)} - m} \bigr]}/{l}
$$

这里的 $x{(1)}$ 指的就是 $x^{(1)}$,$x{(2)}$ 指的就是 $x^{(2)}$ (渲染问题)

📚 例子:四维向量Softmax分块计算

假设我们有一个长度为4的向量:
$$
x = [2, 5, 3, 1]
$$

将它分成两块:$x^{(1)} = [2, 5]$,$x^{(2)} = [3, 1]$

1. 计算第一块 $x^{(1)}$

2. 计算第二块 $x^{(2)}$

3. 合并两块

整个 FlashAttention 的算法过程如下:

5月 06, 2025
4月 27, 2025
ufw
4月 06, 2025
ufw