Swin Transformer 详解
paperreading
本文字数:586 字 | 阅读时长 ≈ 2 min

Swin Transformer 详解

paperreading
本文字数:586 字 | 阅读时长 ≈ 2 min

论文地址:https://arxiv.org/pdf/2103.14030.pdf
代码地址:https://github.com/microsoft/Swin-Transformer

本文一共分为三个部分,首先介绍 Swin Transformer 的整体架构,随后会介绍每个模块的作用,中间会穿插部分代码。本文的主要目的还是希望能够将 Swin Transformer 解释清楚,然后结合官方代码来理解

1. Overall Architecture

首先给出论文中的 Swin Transformer 架构图

左边是 Swin Transformer 的全局架构,它包含 Patch Partition、Linear Embedding、Swin Transformer Block、Patch Merging 四大部分,这四大部分我们之后会进行详细的介绍

右边是 Swin Transformer Block 结构图,这是两个连续的 Swin Transformer Block 块,一个是 W-MSA,一个是 SW-MSA,也就是说根据 Swin 的 Tiny 版本,图中的 Swin Transformer Block 块为[2, 2, 6, 2],相对应的 attention 为:stage1 W-MSA-->SW-MSAstage2 W-MSA-->SW-MSAstage3 W-MSA-->SW-MSA-->W-MSA-->SW-MSA-->W-MSA-->SW-MSAstage4 W-MSA-->SW-MSA

2. Swin Transformer

下面的维度等均是基于 Swin-T 版本

2.1 Patch Partition & Linear Embedding

输入为(B, 3, 224, 224)
输出为(B, 96, 56, 56) —> (B, 96, 224/4=56, 224/4=56)

这两步在论文中其实就是一步实现,我们先来看 paper 中的解释:

在真正实现的时候 paper 使用了 PatchEmbed 函数将这两步结合起来,实际上也就是用了一个卷积的操作,卷积核大小为(4, 4),步长为 4:nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

下面图示此过程

2.2 Basic Layer

在官方的代码库中,将 Swin Transformer Block 和 Patch Merging 合并成了一个,叫做 Basic Layer,下面我们分别介绍这两者

Swin Transformer Block

输入为(B, 3136, 96)
输出为(B, 3136, 96)
就是把上一步的(4, 96, 56, 56)后两维度合并变为(4, 96, 3136),然后后两维互换(4, 3136, 96)

Swin Transformer Block 的输入输出不变,每两个连续 Block 为一组,即一个 Window Multi-head Self-Attention 和一个 Shifted Window Multi-head Self-Attention

下面是 paper 中的 Swin Transformer Block 示例图

从图中我们可以看出每两个连续 Block 块有四小步:

1. 第一个 Block

2. 第二个 Block

从上面四步可以看出 Swin Transformer Block 清晰的执行步骤,其中比较难理解的是 W-MSASW-MSA,下面我们详细介绍二者,并介绍由二者引出的一些细节

(1)first block

包含两个主要模块,W-MSA 和 MLP
输入为(B, 3136, 96)
输出为(B, 3136, 96)

W-MSA

window partition

W-MSA 在第一个 block 中,这一步没有滑动窗,输入为(B, 3136, 96),为了后面的 sefl-attention 操作,需要将特征图划分为一个个窗口的形式,首先经历了一个 window partition 操作,变为(64B, 7, 7, 96)

怎么计算的呢?输入为 batch=B,3136=56*56,特征图有 96 个,将每个特征图 56*56 分为 7*7 的窗口,一共能分 8*8=64 个,乘上之前的 B 就是 64B 了,就是说将特征图分为(7, 7)的小窗,然后把所有的小窗拿出来一共有 64B 个,示例图如下

==为什么要进行 window partition?在 Vision Transformer 中,我们将图片分成了一个个 patch(也就是左边的图),在进行 MSA 时,任何一个 patch 都要与其他所有的 patch 都进行 attention,当 patch 的大小固定时,计算量与图片的大小成平方增长。Swin Transformer 中采用了 W-MSA,也就是 window 的形式,不同的 window 包含了相同数量的 patch,只对 window 内部进行 MSA,当图片大小增大时,计算量仅仅是呈线性增加(只增加了图片多余部分的计算量,比如之前是 224 的图像,现在是 256 的图像,只多了 256-224=32 像素的计算部分),下面详细介绍 window attention 部分==

window attention

将窗口分配完成后就可以执行 attention 操作了,首先我们将维度变为(64B, 49, 96),进行 attention 操作时,我们需要 qkv 三个变量,transformer 是通过 linear 函数来实现的:nn.Linear(dim, dim * 3, bias=qkv_bias),通过这个函数后,维度变为(64B, 49, 288),qkv 分别占三分之一,也就是说 qkv 分别为(64B, 49, 96),第一个阶段的 head 为 3,维度划分为(64B, 3, 49, 32)

此时 qkv 的值如下所示,这就是进行 attention 时 qkv 的维度

接下来就是进行 attention 操作,熟悉 transformer 的同学肯定很容易理解

$$
Attention(Q,K,V) = SoftMax(\frac{QK^{T}}{\sqrt{d}}+B)V
$$

注意这里加了一个偏置 B,在最后会详细介绍相对位置偏置(Relative Position Bias)的原理

window reverse

所有 attention 步骤执行完之后就可以回到 attention 之前的维度(64B, 7, 7, 96),然后我们经过一个 window reverse 操作就可以回到 window partition 之前的状态了,即(B, 56, 56, 96)。window reverse 就是 window partition 的逆过程

总结:这里总结一下 W-MSA 所做的事情,首先进行 window partition 操作,维度从(B, 3136, 96)也就是(B, 56, 56, 96)变为(64B, 7, 7, 96);随后进行 attention 操作,先经过一个线性层维度变为三倍来为 qkv 分别赋值(64B, 49, 96*3): qkv(64B, 49, 96),随后根据 multi-head 操作在将 qkv 分别分成三份,(64B, 3, 49, 32),最后进行 attention 操作(即上面的公式),然后通过 window reverse 回到最初的状态(B, 56, 56, 96),也就是(B, 3136, 96),下面图示了这一阶段的过程

MLP

输入为(4, 3136, 96)
输出为(4, 3136, 96)

再经过第二个 Block 之前要先经过一个 MLP,其中结构为

Linear(96, 96*4)——GELU()——Linear(96*4, 96)——Dropout

最终维度并不发生变化

(2)second block

包含两个主要模块,SW-MSA 和 MLP
输入为(4, 3136, 96)
输出为(4, 3136, 96)

与第一个 Block 唯一不同的地方就是 SW-MSA 模块,所以这里仅讲解此模块

SW-MSA

与 W-MSA 不同的地方在于这个模块存在滑动,所以叫做 shifted window,滑动的距离为 win_size//2 在这里也就是 7//2=3,这里用 image(4, 4) win(2, 2) shift=1 来图示他的 shift 以及 mask 机制

这里先给出 Github 上有助于理解此机制的提问:链接

为什么要用 mask 机制呢,Swin Transformer 与 Vision Transformer 相比虽然降低了计算量,但缺点是同一个 window 里面的 patch 可以交互,window 与 window 之间无法交互,所以考虑滑动窗的方法,如上图所示,滑动过后为了保证图片的完整性,我们将上面和左边的图补齐到右边,这又带来了一个缺点:图片的右端和补齐的图片本身并不是相邻的,所以无法交互,解决办法就是 mask

Swin Transformer 的 mask 机制是说,如果相互交互的 patch 属于同一个区域(对应于上图的颜色),那么就可以正常交互,如果不是同一个区域(对应于上图的不同颜色),那么他们交互之后就需要加上一个很大的负值,这样通过 softmax 层之后本来不能交互的那个像素就变成 0 了,这就是 mask 机制

这里附上 Github 上讨论的一个源码,由此可以直接看到 mask 是如何运行的,这个代码与我上述的图是对应的

import torch
import torch.nn as nn


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


window_size = 2
shift_size = 1
H, W = 4, 4
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))

cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask = attn_mask.unsqueeze(1).unsqueeze(0)
print(attn_mask)

"""
tensor([[[[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]],


         [[[   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.],
           [   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.]]],


         [[[   0.,    0., -100., -100.],
           [   0.,    0., -100., -100.],
           [-100., -100.,    0.,    0.],
           [-100., -100.,    0.,    0.]]],


         [[[   0., -100., -100., -100.],
           [-100.,    0., -100., -100.],
           [-100., -100.,    0., -100.],
           [-100., -100., -100.,    0.]]]]])
"""

2.3 Patch Merging

在每个 Stage 结束的阶段都有一个 Patch Merging 的过程,这个过程会让输入进行降维,同时通道变为原来的二倍,用一个图来清晰的展示此过程,图示如下

上面说到过 Swin 的作用是使得 patch 交互的区域变大,另一种使其变大的方法就是这里提到的 Patch Merging,在每个阶段结束之后,将特征图的维度减半,channel 加倍,在保持 patch 和 window 不变的情况下相当于变相提高了 patch 和 window 的感受野,使其效果更好

到这里 Swin Transformer 的一个 stage 就已经讲完了,其余的 Stage 和上面讲述的完全一致,为了再次强化 Swin Transformer 的整个流程,下面是整个流程展示,其中加粗部分为我们已经走过的流程(这里依然是 Swin-Tiny 版本)

input-->patch partition-->linear embedding
stage1 W-MSA-->MLP-->SW-MSA-->MLP
stage2 W-MSA-->MLP-->SW-MSA-->MLP
stage3 W-MSA-->MLP-->SW-MSA-->MLP *3
stage4 W-MSA-->MLP-->SW-MSA-->MLP-->tail process

3. Supplement

3.1 Relative Position Bias

到这里整个 Swin Transformer 就已经讲完了,还记得 attention 中加了一个 bias B 吗,这里对其进行讲解,依旧取 win=2,如下所示

这里的相对位置偏置这样理解,在窗口中任意选定一个坐标,遵循 左+右-上+下- 的原则,可以发现当我们将左上角的值为 (0, 0) 时,他右边的位置为 (0, -1) 减了 1,下面的位置为 (-1, 0) 也减了 1,同理将其他位置设为 (0, 0) 时,结果分别如图所示

然后我们将其展开,执行:行列分别加 M-1=2-1=1,行标乘 2M-1=3,最终可以得到下图,然后需要注意的是最大值为 8,也就是说一共有 9 个索引,为什么有四个像素,按理来说为 4*4=16 个位置,只有 9 个索引呢?这是因为是相对位置编码位置有重复,又因为 win=2,所以行和列的索引均为 [-1, 1],一共 3*3=9 种组合,即九个相对位置索引,因此相对位置索引表一共有 9 个数字,如下图所示

其中上面是索引表(9 个数),下面是索引后的结果

为了更清晰的认识相对位置偏置,这里给出一个简单的 example

# relative_position_bias_table (1, 9)
relative_position_bias_table = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80, 90])

# relative_position_index (4, 4)
window_size = [2, 2]
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

# index (4, 4)
table = relative_position_bias_table[relative_position_index.view(-1)].view(window_size[0]*window_size[1], window_size[0]*window_size[1], -1)
table = table.permute(2, 0, 1).contiguous().unsqueeze(0)

print("relative_position_index\n", relative_position_index)
print(table)

'''
relative_position_index
 tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])
tensor([[[[50, 40, 20, 10],
          [60, 50, 30, 20],
          [80, 70, 50, 40],
          [90, 80, 60, 50]]]])
'''

到这里 Swin Transformer 就讲完啦,但是因为写的比较仓促有一些地方讲的不够细致,还有关于 FLOPs 运算的细节没有讲到,后面有时间会再补充~

4月 06, 2025
3月 10, 2025
12月 31, 2024