RoPE 位置编码
paperreading
本文字数:2.3k 字 | 阅读时长 ≈ 11 min

RoPE 位置编码

paperreading
本文字数:2.3k 字 | 阅读时长 ≈ 11 min

旋转位置编码(简称 RoPE)是一种用于 Transformer 类模型的相对位置编码方法,在像 Llama 等模型中被广泛应用。

1. 1D RoPE

1D 就是为“一维序列(如文本)”设计的 Rotary Position Embedding。即每个token 都有一个对应的位置索引(pos),然后将这个位置嵌入方式“旋转”地作用于每个特征维度的表示。

RoPE 利用复数相乘等效于向量旋转,在 embedding 空间中按“角度”旋转来表达位置。因此位置变化 → token 表示旋转,天然体现相对位置。

以 Transformer 的 Q/K/V 为例,每个 token embedding 是 $d$ 维(一般是偶数,或能拆成偶数组)。首先构建频率序列,对每个组(一般一组2维,对应偶数$i$和$i+1$),定义:
$$
\theta_j = 10000^{\frac{-2j}{d}}, \quad j=0,1,…,\frac{d}{2}-1
$$

对每个位置 $p$,定义旋转角 $p \cdot \theta_j$,对于每组 embedding $(x_{2j}, x_{2j+1})$,做二维旋转
$$
\begin{bmatrix}x’_{2j} \\ x’_{2j+1} \end{bmatrix} =
\begin{bmatrix}
\cos(p \cdot \theta_j) & -\sin(p \cdot \theta_j) \\
\sin(p \cdot \theta_j) & \cos(p \cdot \theta_j)
\end{bmatrix}
\begin{bmatrix} x_{2j} \\ x_{2j+1}\end{bmatrix}
$$

即可看作把 embedding 的每对分量作为二维向量,在平面上转相应的角度。

复数表达

$(x_{2j}, x_{2j+1})$ 可合成 $z_j = x_{2j} + i x_{2j+1}$,那么旋转就是:$z_j’ = z_j \cdot e^{i(p \theta_j)}$

假设你有二维点 $(a, b)$,把它当作复数 $z = a + ib$。
如果用模为1的复数 $e^{i\phi} = \cos\phi + i\sin\phi$ 去乘 $z$,就会把点$(a,b)$绕原点旋转$\phi$角度。

$$
z^\prime = z \cdot e^{i\phi} = (a + ib)(\cos\phi + i\sin\phi) \\
= a\cos\phi - b\sin\phi + i(a\sin\phi + b\cos\phi)
$$

这正好就是二维旋转矩阵

下面举个例子说明一下

import torch
import math
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

# 中文字体设置(macOS)
matplotlib.rcParams['font.family'] = "Heiti TC"
matplotlib.rcParams['axes.unicode_minus'] = False

def apply_rope(x: torch.Tensor, position: int):
    D = x.shape[0]
    assert D % 2 == 0
    x_roped = x.clone()
    for i in range(0, D, 2):
        freq_idx = i // 2
        theta = position / (10000 ** (2 * freq_idx / D))
        cos_theta = math.cos(theta)
        sin_theta = math.sin(theta)
        xi, xi1 = x[i], x[i + 1]
        x_roped[i]     = cos_theta * xi - sin_theta * xi1
        x_roped[i + 1] = sin_theta * xi + cos_theta * xi1
    return x_roped

# 原始 token 向量
# x = torch.arange(12).float()
x = torch.ones(12).float() * 5
positions = [0, 1, 2, 5, 10]
roped_vectors = [apply_rope(x, pos) for pos in positions]

# 展示所有偶数维度对
D = x.shape[0]
num_pairs = D // 2
nrows, ncols = 2, 3
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 5 * nrows))
axes = axes.flatten()
colors = cm.viridis(np.linspace(0, 1, len(positions)))

for plot_idx in range(num_pairs):
    i = plot_idx * 2
    ax = axes[plot_idx]
    ax.set_title(f"维度对 ({i},{i+1}) 的旋转效果")
    ax.set_xlim(-15, 15)
    ax.set_ylim(-15, 15)
    ax.axhline(0, color='gray', lw=0.5)
    ax.axvline(0, color='gray', lw=0.5)
    ax.set_aspect('equal')

    handles = []

    # 原始向量
    xi, xi1 = x[i], x[i + 1]
    q = ax.quiver(0, 0, xi.item(), xi1.item(), angles='xy', scale_units='xy', scale=1,
                  color='black', width=0.015, label='原始向量', zorder=10)
    handles.append(q)

    # 各位置旋转向量
    for j, pos in enumerate(positions):
        xi_r, xi1_r = roped_vectors[j][i], roped_vectors[j][i + 1]
        q = ax.quiver(0, 0, xi_r.item(), xi1_r.item(), angles='xy', scale_units='xy', scale=1,
                      color=colors[j], width=0.01, label=f'pos={pos}')
        handles.append(q)

    ax.legend(handles=handles, loc='upper left')

# 隐藏多余子图
for idx in range(num_pairs, nrows * ncols):
    axes[idx].axis('off')

plt.tight_layout()
plt.savefig('rope.png', dpi=300)  # 保存为 PNG 文件
plt.show()

可视化结果

这是呈现的结果,在RoPE中,theta的计算方式如下$\theta_j = \frac{\text{position}}{10000^{2j/D}}$

采用“指数衰减”的原因?

将频率(角度步长)做成指数衰减,可以确保:既可以编码“很短距离的”相对位置信息(如某词的邻居是谁)——靠低维做细致区分,也能反映“很长距离的”全局顺序信息——靠高维做粗略编码,这与原始Transformer的正弦位置编码里的“多尺度频率”思路与Motivation是一样的!

为什么同一 pos 的 embedding 在不同维度上旋转角度不同?

因为每个维度的旋转角度 $\theta_j$ 是根据该维度的频率 $j$ 计算的。低频维度(小 $j$)的旋转角度较大,能更细致地捕捉位置变化;高频维度(大 $j$)的旋转角度较小,能更平滑地捕捉全局顺序信息。
这种设计使得模型能够同时感知局部和全局的位置信息。如果所有维度同角度旋转,RoPE的“不同空间编码多尺度位置关系”的能力就没了,模型只能获得极弱的位置感知力,信息容量极低,也失去跨序列长度泛化的意义,这就是必须多频率的本质原因!

2. 2D RoPE

Transformer 在处理图像时,每个 patch/pixel 有二维空间坐标 $(pos_x, pos_y)$。1D RoPE 编码一维序列的位置,而 2D RoPE 则致力于高效表示 二维网格上每个点的空间位置信息

核心思想是将 embedding 按 x, y 两个方向各自“分管一半”,分别用 pos_xpos_y 采用 1D RoPE 的方法进行独立旋转。

设 embedding 维度为 $D$, 分成前后各一半:

每半仍按一维 RoPE 那样,分成若干组二维向量(每组为 $(x_{i}, x_{i+1})$)各自以多尺度频率旋转。对 $D$ 维 embedding $x$,$pos_x, pos_y$ 为横纵位置,有:

前半:$i = 0,2,4,\dots, D/2-2$
$$
\theta^{(x)}_j = \frac{pos_x}{10000^{\frac{2j}/{(D/2)}}}
$$

$$
\begin{bmatrix} x’_{i} \\ x’_{i+1} \end{bmatrix} =
\begin{bmatrix}
\cos \theta_j^{(x)} & -\sin \theta_j^{(x)} \\
\sin \theta_j^{(x)} & \cos \theta_j^{(x)}
\end{bmatrix}
\begin{bmatrix} x_{i} \\ x_{i+1} \end{bmatrix}
$$

其中 $j = i/2$,后半:$i = D/2, D/2+2, …, D-2$
$$
\theta^{(y)}_j = \frac{pos_y}{10000^{\frac{2j}/{(D/2)}}}
$$

$$
\begin{bmatrix} x’_{i} \\ x’_{i+1} \end{bmatrix} =
\begin{bmatrix}
\cos \theta_j^{(y)} & -\sin \theta_j^{(y)} \\
\sin \theta_j^{(y)} & \cos \theta_j^{(y)}
\end{bmatrix}
\begin{bmatrix} x_{i} \\ x_{i+1} \end{bmatrix}
$$

其中 $j = (i - D/2)/2$,这里每一半仍然有自己的多尺度频率,依赖于二维坐标的不同分部。

复数形式

类比 1D RoPE:每个二维小组可用 $z^{(x)}_j = x_{2j} + i x_{2j+1}$,旋转为 $z^{(x)} _{j} \cdot e^{i\theta _{j}{(x)}}$,$z{(y)}_j = x_{D/2+2j} + i x_{D/2+2j+1}$,旋转 $z^{(y)} _{j} \cdot e^{i\theta _{j}^{(y)}}$。

import torch
import math
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

# 中文字体设置(macOS,可注释掉)
matplotlib.rcParams['font.family'] = "Heiti TC"
matplotlib.rcParams['axes.unicode_minus'] = False

def apply_2d_rope(x: torch.Tensor, pos_x: int, pos_y: int):
    """
    x: [D], D 为 4 的倍数
    pos_x, pos_y: 坐标
    前一半用 pos_x, 后一半用 pos_y
    """
    D = x.shape[0]
    assert D % 4 == 0, "embedding dim must be a mul of 4"
    x_roped = x.clone()
    half_D = D // 2
    # 前半: 用pos_x
    for i in range(0, half_D, 2):
        freq_idx = i // 2
        theta = pos_x / (10000 ** (2 * freq_idx / half_D))
        cos_theta = math.cos(theta)
        sin_theta = math.sin(theta)
        xi, xi1 = x[i], x[i + 1]
        x_roped[i]     = cos_theta * xi - sin_theta * xi1
        x_roped[i + 1] = sin_theta * xi + cos_theta * xi1
    # 后半: 用pos_y
    for i in range(half_D, D, 2):
        freq_idx = (i - half_D) // 2
        theta = pos_y / (10000 ** (2 * freq_idx / half_D))
        cos_theta = math.cos(theta)
        sin_theta = math.sin(theta)
        xi, xi1 = x[i], x[i + 1]
        x_roped[i]     = cos_theta * xi - sin_theta * xi1
        x_roped[i + 1] = sin_theta * xi + cos_theta * xi1
    return x_roped

# 设置 embedding
x = torch.ones(12).float() * 4
# 测试的 (pos_x, pos_y) 组合
positions = [(0, 0), (2, 0), (0, 2), (2, 2), (4, 4), (8, 4)]
roped_vectors = [apply_2d_rope(x, px, py) for px, py in positions]

# 展示所有维度对
D = x.shape[0]
num_pairs = D // 2
nrows, ncols = 2, 3     # 可根据 embedding 更大可增大
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 5 * nrows))
axes = axes.flatten()
colors = cm.viridis(np.linspace(0, 1, len(positions)))

for plot_idx in range(num_pairs):
    i = plot_idx * 2
    ax = axes[plot_idx]
    if i < D // 2:
        title = f"维度({i},{i+1}) (受 pos_x 影响)"
    else:
        title = f"维度({i},{i+1}) (受 pos_y 影响)"
    ax.set_title(title)
    ax.set_xlim(-15, 15)
    ax.set_ylim(-15, 15)
    ax.axhline(0, color='gray', lw=0.5)
    ax.axvline(0, color='gray', lw=0.5)
    ax.set_aspect('equal')

    handles = []

    # 原始向量
    xi, xi1 = x[i], x[i + 1]
    q = ax.quiver(0, 0, xi.item(), xi1.item(), angles='xy', scale_units='xy', scale=1,
                  color='black', width=0.015, label='原始向量', zorder=10)
    handles.append(q)

    # 各 (pos_x, pos_y) 旋转后的向量
    for j, (px, py) in enumerate(positions):
        xi_r, xi1_r = roped_vectors[j][i], roped_vectors[j][i + 1]
        pos_label = f'({px},{py})'
        q = ax.quiver(0, 0, xi_r.item(), xi1_r.item(), angles='xy', scale_units='xy', scale=1,
                      color=colors[j], width=0.01, label=pos_label)
        handles.append(q)

    ax.legend(handles=handles, loc='upper left', fontsize=8)

# 隐藏多余子图
for idx in range(num_pairs, nrows * ncols):
    axes[idx].axis('off')

plt.tight_layout()
plt.savefig('2d_rope.png', dpi=300)
plt.show()

可视化结果

图中每一子图展示embedding一个二维对(如第0/1维),不同颜色箭头代表不同(x, y)格的旋转结果。

和1D RoPE的区别/类比

1D RoPE 2D RoPE
输入数据 一维序列 二维网格
坐标功能 单坐标pos 横纵坐标 $(pos_x,pos_y)$
旋转规则 全部分组用pos 一半用pos_x,另一半用pos_y
能力 顺序相关(文本/序列) 空间相关(图像/表格/patch)