A Detailed Explanation of the MultiHeadAttention Function

I went through the attention mechanism in the Transformer in detail in an earlier post. When I came to actually use a MultiHeadAttention layer, I found that PyTorch already provides an implementation, but the documentation does not explain it very well, so this post dissects how to use it starting from the underlying logic.

1. A custom MultiHeadAttention function

Before turning to the official PyTorch module, let's hand-write an equivalent function to get a rough idea of how it works.

We use an NLP input as the example. The input has shape [seq_length, batch_size, d_model]: the first dimension is the number of tokens, the second is the batch, and the third is the dimension of each token. For example, for the sentence "I love you" the input is [3, 1, 6], i.e. three tokens, each of dimension 6. The same reasoning applies in CV: the first dimension is the number of image patches, the second is the batch, and the third is the dimension of each patch.
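For concreteness, here is a minimal sketch of how such a [3, 1, 6] tensor could be produced for "I love you" (the toy vocabulary and the nn.Embedding layer are made up purely for illustration):

import torch
import torch.nn as nn

# hypothetical toy vocabulary: "I" -> 0, "love" -> 1, "you" -> 2
tokens = torch.tensor([[0, 1, 2]])                 # (batch_size, seq_length) = (1, 3)
embedding = nn.Embedding(num_embeddings=3, embedding_dim=6)

x = embedding(tokens).transpose(0, 1)              # (seq_length, batch_size, d_model)
print(x.size())                                    # torch.Size([3, 1, 6])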

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

class MultiHeadAttention(nn.Module):
    """
    Multi head attention module

    Args:
        h_head: number of heads
        d_model: dimension of input
        dropout: dropout rate
    """
    def __init__(self, h_head, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h_head == 0
        self.h_head = h_head
        self.d_k = d_model // h_head

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)

        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h_head heads.
            mask = mask.unsqueeze(1)
        batch_size = query.size(1)  # dim 1 is the batch for the (seq_length, batch_size, d_model) layout used here

        print ('Before transform query: ' + str(query.size())) # (seq_length, batch_size, d_model)

        # note: view() below treats the tensor as batch-first; with the
        # seq-first input used here this only lines up because batch_size == 1
        query = self.wq(query).view(batch_size, -1, self.h_head, self.d_k).transpose(1, 2)
        key   = self.wk(key).view(batch_size, -1, self.h_head, self.d_k).transpose(1, 2)
        value = self.wv(value).view(batch_size, -1, self.h_head, self.d_k).transpose(1, 2)
                
        print ('After transform query: ' + str(query.size()))
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h_head * self.d_k)
        x = self.fc(x)
        return x, self.attn

n_head = 2
d_model = 6
batch_size = 1
seq_length = 3
model = MultiHeadAttention(n_head, d_model)

query = torch.randn([seq_length, batch_size, d_model])  # [3, 1, 6]
key = query
value = query
print ('Input size: ' + str(query.size()))
output, attn = model(query, key, value)
print ('Output size: ' + str(output.size()))

"""
Input size: torch.Size([3, 1, 6])
Before transform query: torch.Size([3, 1, 6])
After transform query: torch.Size([1, 2, 3, 3])
Output size: torch.Size([1, 3, 6])
"""

In the code above, MultiHeadAttention holds four sets of parameters: three Linear layers that produce query, key and value, and a final Linear layer that maps the attention result back to the same dimension as the input. Since the input and the q/k/v dimensions are all the same here, the Linear layers only add capacity to the network rather than changing any shapes. Also note that the output comes out as [1, 3, 6], i.e. batch-first, whereas the input was seq-first [3, 1, 6]; with batch_size = 1 this is just a reshuffling of the same dimensions.
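To see these four Linear layers explicitly, you can list the parameters of the model instance built above (shapes shown as comments, with d_model = 6):

for name, param in model.named_parameters():
    print(name, tuple(param.size()))

# wq.weight (6, 6)    wq.bias (6,)
# wk.weight (6, 6)    wk.bias (6,)
# wv.weight (6, 6)    wv.bias (6,)
# fc.weight (6, 6)    fc.bias (6,)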

2. The nn.MultiheadAttention function

Having understood the underlying logic of MultiHeadAttention, let's now look at how the official PyTorch implementation works.

First, the key points from the official documentation.
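In short: the constructor takes the embedding dimension and the number of heads, nn.MultiheadAttention(embed_dim, num_heads, ...), with batch_first=False by default, i.e. the expected input layout is (seq_length, batch_size, d_model) just like in part 1; forward(query, key, value, ...) returns a tuple (attn_output, attn_output_weights). This is only a paraphrase of the key points, so check the documentation of your PyTorch version for the full argument list.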

import torch
import torch.nn as nn
import numpy as np


n_head = 2
d_model = 6
batch_size = 1
seq_length = 3

attention = nn.MultiheadAttention(d_model, n_head)
print(attention.in_proj_weight.size())
print(attention.in_proj_bias.size())
print(attention.out_proj.weight.size())
print(attention.out_proj.bias.size())

'''
torch.Size([18, 6])
torch.Size([18])
torch.Size([6, 6])
torch.Size([6])
'''
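# The [18, 6] in_proj_weight stacks the three d_model x d_model projection
# matrices for q, k and v along dim 0 (an internal layout detail of this
# module, so treat this check as a sketch); chunking recovers the three:
w_q, w_k, w_v = attention.in_proj_weight.chunk(3, dim=0)
print(w_q.size(), w_k.size(), w_v.size())
# torch.Size([6, 6]) torch.Size([6, 6]) torch.Size([6, 6])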

# Overwrite the in_proj parameters of MultiheadAttention:
# W_q is all ones, W_k all twos, W_v all threes, and the biases are zero
wq = torch.Tensor(np.ones((6, 6)))
wk = torch.Tensor(np.ones((6, 6))) * 2
wv = torch.Tensor(np.ones((6, 6))) * 3
weight = torch.nn.Parameter(torch.cat([wq, wk, wv], dim=0))
attention.in_proj_weight.data = weight
attention.in_proj_bias.data = torch.nn.Parameter(torch.Tensor(np.zeros((18,))))

# Overwrite the out_proj parameters of MultiheadAttention: all-ones weight, zero bias
fc_weight = torch.nn.Parameter(torch.Tensor(np.ones((6, 6))))
fc_bias = torch.nn.Parameter(torch.Tensor(np.zeros((6,))))
attention.out_proj.weight = fc_weight
attention.out_proj.bias = fc_bias


# Define the input: all ones, shape (seq_length, batch_size, d_model)
x = torch.ones([seq_length, batch_size, d_model])

output, attn = attention(x, x, x, average_attn_weights=False)  # defaults to True; False returns per-head attention weights
print(output)
print(output.size())

print(attn)
print(attn.size())

'''
# output
tensor([[[108., 108., 108., 108., 108., 108.]],

        [[108., 108., 108., 108., 108., 108.]],

        [[108., 108., 108., 108., 108., 108.]]], grad_fn=<ViewBackward0>)
torch.Size([3, 1, 6])

# attention
tensor([[[[0.3333, 0.3333, 0.3333],
          [0.3333, 0.3333, 0.3333],
          [0.3333, 0.3333, 0.3333]],

         [[0.3333, 0.3333, 0.3333],
          [0.3333, 0.3333, 0.3333],
          [0.3333, 0.3333, 0.3333]]]], grad_fn=<ViewBackward0>)
torch.Size([1, 2, 3, 3])
'''

First, nn.MultiheadAttention has four parameter tensors: in_proj_weight, in_proj_bias, out_proj.weight and out_proj.bias. The first two are used to obtain query, key and value; the last two form the linear layer that maps the attention output back to the same dimension as the input. This mirrors part 1 exactly and the two can be read side by side. Now let's go through what the code above actually does:

  1. It overwrites in_proj_weight and in_proj_bias: the projection matrices for query, key and value are filled entirely with 1, 2 and 3 respectively, and the bias is set to all zeros.
  2. It overwrites out_proj.weight and out_proj.bias: the output projection matrix is filled entirely with 1 and its bias with 0. This is the layer that maps the attention result back to the same dimension as the input when the two differ.

Step 1 can be understood with the help of the figure below.

Step 2 can be understood with the figure below: after step 1 finishes, the different heads are concatenated and then passed through the linear layer that maps the result back to the same dimension as the input.
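With these all-ones weights, the numbers in the output above are easy to reproduce by hand. Here is a minimal sketch reusing the wq, wk, wv, fc_weight and x tensors defined earlier:

# each projection just sums the 6 input ones (biases are zero)
q = x @ wq.t()                   # every entry is 6
k = x @ wk.t()                   # every entry is 12
v = x @ wv.t()                   # every entry is 18

# all attention scores are equal, so softmax gives uniform 1/3 weights per head
# and the weighted average of v is still 18 everywhere
attn_out = v                     # shape (3, 1, 6), filled with 18

out = attn_out @ fc_weight.t()   # 18 summed over 6 dims = 108
print(out[0, 0])                 # six values of 108, matching the output above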

Comparing the two parts, we can see that nn.MultiheadAttention performs no LayerNorm; it implements the multi-head attention mechanism only.
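If you want the complete Transformer sub-layer, the residual connection and LayerNorm have to be added around it yourself. A minimal post-norm sketch (following the arrangement of the original Transformer paper, not something nn.MultiheadAttention provides):

norm = nn.LayerNorm(d_model)

attn_out, _ = attention(x, x, x)   # multi-head attention only
y = norm(x + attn_out)             # residual connection + LayerNorm added by hand
print(y.size())                    # torch.Size([3, 1, 6])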