关于 ViT 中 pos embedding 的可视化
paperreading
本文字数:682 字 | 阅读时长 ≈ 3 min

关于 ViT 中 pos embedding 的可视化

paperreading
本文字数:682 字 | 阅读时长 ≈ 3 min

在 ViT 中有一个 position embedding 部分,为什么要有这一部分呢?

在 NLP 中,不同词转化为 Token 之后有一个位置编码的模块,这是因为不同词汇之间是有顺序的,但是在视觉领域,图像与图像之间是没有顺序的,ViT 将每一幅图划分为一个个 patch,如下图所示,每一个 patch 就对应于 NLP 中的一个 Token,而且从图中也可以直观的感受到每一个 patch 都是有位置的,所以在每一个特征维度上都加入了一个 position embedding 模块,最后我们可视化一下 Google 预训练后 position embedding 的结果

可视化左上角的 patch

可视化所有 patch

这是一幅图中所有 patch 可视化的结果,但是因为 patch 太多,不是很清晰,但是还是可以看出大体的位置效果

注意
假设我们的 Patch 一共是 576 个,那么计算出来的每一个可视化图都是 576 维也就是$24\times 24$,每一维度都是计算余弦相似度。以左上角的第一幅图为例,先计算第一维与自己的余弦相似度,在计算他与其他 575 维的余弦相似度,最后得到 576 个值,reshape 为$(24, 24)$可视化即可,通过可视化结果我们可以发现他与自己的相似度最高,与他同行或者同列的相似度次之,其余的相似度最小。下面是余弦相似度的计算公式

$$
similarity = cos(\theta) = \frac{A\cdot B}{||A||||B||} = \frac{\sum_{i=1}{n}A_{i}B_{i}}{\sqrt{\sum_{i=1}{n}(A_{i})2}\sqrt{\sum_{i=1}{n}(B_{i})^2}}
$$

下面直接给出代码

首先下载预训练模型,链接,然后放到 py 相同的文件夹下运行即可

# show position embedding picture
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


def bit_product_sum(x, y):
    return sum([item[0] * item[1] for item in zip(x, y)])


def cosine_similarity(x, y, norm=False):
    """ 计算两个向量x和y的余弦相似度 """
    assert len(x) == len(y), "len(x) != len(y)"

    xy = x.dot(y)
    x2y2 = np.linalg.norm(x, ord=2) * np.linalg.norm(x, ord=2) 
    sim = xy/x2y2
    return sim


data = np.load("imagenet21k+imagenet2012_ViT-B_16.npz")
pos = data['Transformer/posembed_input/pos_embedding'].reshape(577, 768)[1:, :]  # 576, 768

cos = np.zeros((576, 576))
# 只计算左上角的值
# for i in tqdm(range(1)):
#     for j in range(576):
#         cos[i, j] = cosine_similarity(pos[i, :], pos[j, :])
# cos = cos[0, :].reshape(24, 24)
# plt.imshow(cos)
# plt.show()

# 计算所有
for i in tqdm(range(576)):
    for j in range(576):
            cos[i, j] = cosine_similarity(pos[i, :], pos[j, :])


fig, axs = plt.subplots(nrows=24, ncols=24, figsize=(24, 24),
                subplot_kw={'xticks': [], 'yticks': []})
i=0
cos = cos.reshape(576, 24, 24)
for ax in axs.flat:
    ax.imshow(cos[i, :, :],  cmap='viridis')
    i+=1
plt.tight_layout()
plt.show()