nn.LayerNorm 实现及原理
pytorch
本文字数:2k 字 | 阅读时长 ≈ 10 min

nn.LayerNorm 实现及原理

pytorch
本文字数:2k 字 | 阅读时长 ≈ 10 min

1. nn.LayerNorm 函数

nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)

$$
y = \frac{x-E(x)}{\sqrt{Var(x)+\epsilon}}*\gamma+\beta
$$

2. LayerNorm 在 Transformer 中的应用

在 transformer 中一般采用 LayerNorm,LayerNorm 也是归一化的一种方法,与 BatchNorm 不同的是它是对每单个 batch 进行的归一化,而 batchnorm 是对所有 batch 一起进行归一化的

在 Transformer 中,给定输入 tensor[seq_length, batch_size, d_model],其中 seq_length 为序列长度,batch_size 为 batch 大小,d_model 为 embedding 的维度,如下所示,我们有多个 batch,其中第一个 batch 的序列长度为 3,包括“I, Love, You”,他们的词向量维度均为 6,在 LayerNorm 的时候,分别对“I”,“Love”,“You”向量进行归一化,即相同颜色的数字

2.1 LayerNorm 的官方实现

下面代码展示了 nn.LayerNorm 的官方使用

import torch.nn as nn
import torch

input = torch.arange(1, 19).view(3, 1, 6).type(torch.float32)
print(input)

'''
tensor([[[ 1.,  2.,  3.,  4.,  5.,  6.]],

        [[ 7.,  8.,  9., 10., 11., 12.]],

        [[13., 14., 15., 16., 17., 18.]]])
'''

# 官方nn.LayerNorm实现
norm = nn.LayerNorm(6)
output = norm(input)
print(output)

'''
tensor([[[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]],

        [[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]],

        [[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]]],
       grad_fn=<NativeLayerNormBackward0>)
'''

# 手动计算验证
mean = torch.mean(input, dim=-1, keepdim=True)
# 这里方差计算是除了N,不是N-1
std = torch.std(input, correction=0, dim=-1, keepdim=True)
output = (input - mean) / (std + 1e-5)
print(output)

'''
tensor([[[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]],

        [[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]],

        [[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]]])
'''

2.2 LayerNorm 的自定义实现

下面代码展示了 nn.LayerNorm 的自定义实现

import torch.nn as nn
import torch

class LayerNorm(nn.Module):
    "Construct a layernorm module."

    def __init__(self, features, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, correction=0, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


input = torch.arange(1, 19).view(3, 1, 6).type(torch.float32)
print(input)
'''
tensor([[[ 1.,  2.,  3.,  4.,  5.,  6.]],

        [[ 7.,  8.,  9., 10., 11., 12.]],

        [[13., 14., 15., 16., 17., 18.]]])
'''

norm = LayerNorm(6)
output = norm(input)
print(output)
'''
tensor([[[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]],

        [[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]],

        [[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]]],
       grad_fn=<AddBackward0>)
'''

mean = torch.mean(input, dim=-1, keepdim=True)
std = torch.std(input, correction=0, dim=-1, keepdim=True)
output = (input - mean) / (std + 1e-5)
print(output)
'''
tensor([[[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]],

        [[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]],

        [[-1.4638, -0.8783, -0.2928,  0.2928,  0.8783,  1.4638]]])
'''

3. nn.LayerNorm 的底层逻辑

根据上述例子我们可以了解到 LayerNorm 的归一化底层逻辑,给定 LayerNorm 归一化维度,他会将输入 tensor 的最后几个维度进行整体归一化。什么意思呢?假设我们的输入为(1, 3, 4, 4)的变量,并对其进行 LayerNorm,这里我们展示两个例子

注意:这里的例子只是帮助理解 LayerNorm 函数的用法,并不是说四维 tensor 就要按照下面两种方式处理,正常来说,CNN 中很少用 LayerNorm

如下图所示,左边为第一种归一化方法,对所有 channel 所有像素计算;右边为第二种归一化方法,对所有 channel 的每个像素分别计算

3.1 第一种计算

直接给出计算代码

注意:输入为(1, 3, 4, 4),layernorm 的 normalized_shape 为[3, 5, 5],也就是说对后三维度进行归一化操作

import torch.nn as nn
import torch


input = torch.arange(1, 49).view(3, 4, 4).type(torch.float32)
input = input.unsqueeze(0)  
'''
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]],

         [[17., 18., 19., 20.],
          [21., 22., 23., 24.],
          [25., 26., 27., 28.],
          [29., 30., 31., 32.]],

         [[33., 34., 35., 36.],
          [37., 38., 39., 40.],
          [41., 42., 43., 44.],
          [45., 46., 47., 48.]]]])
'''


# 直接使用nn.LayerNorm函数计算
norm = nn.LayerNorm([3, 4, 4])
print(norm(input))
'''
tensor([[[[-1.6963, -1.6242, -1.5520, -1.4798],
          [-1.4076, -1.3354, -1.2632, -1.1910],
          [-1.1189, -1.0467, -0.9745, -0.9023],
          [-0.8301, -0.7579, -0.6858, -0.6136]],

         [[-0.5414, -0.4692, -0.3970, -0.3248],
          [-0.2526, -0.1805, -0.1083, -0.0361],
          [ 0.0361,  0.1083,  0.1805,  0.2526],
          [ 0.3248,  0.3970,  0.4692,  0.5414]],

         [[ 0.6136,  0.6858,  0.7579,  0.8301],
          [ 0.9023,  0.9745,  1.0467,  1.1189],
          [ 1.1910,  1.2632,  1.3354,  1.4076],
          [ 1.4798,  1.5520,  1.6241,  1.6963]]]],
       grad_fn=<NativeLayerNormBackward0>)
'''

# 手动计算
mean = torch.mean(input)
std = torch.std(input, correction=0)
x = (input-mean)/(std+1e-5)
print(x)
'''
tensor([[[[-1.6963, -1.6241, -1.5520, -1.4798],
          [-1.4076, -1.3354, -1.2632, -1.1910],
          [-1.1189, -1.0467, -0.9745, -0.9023],
          [-0.8301, -0.7579, -0.6858, -0.6136]],

         [[-0.5414, -0.4692, -0.3970, -0.3248],
          [-0.2526, -0.1805, -0.1083, -0.0361],
          [ 0.0361,  0.1083,  0.1805,  0.2526],
          [ 0.3248,  0.3970,  0.4692,  0.5414]],

         [[ 0.6136,  0.6858,  0.7579,  0.8301],
          [ 0.9023,  0.9745,  1.0467,  1.1189],
          [ 1.1910,  1.2632,  1.3354,  1.4076],
          [ 1.4798,  1.5520,  1.6241,  1.6963]]]])
'''

当然如果要灵活的进行操作,可以将 tensor 提前 resize 以下,这样 LayerNorm 就不需要传入 list 列表了,比如这里将输入 resize 为 [1, 3*4*4],这样初始化 LayerNorm(3*4*4) 即可,等操作完成后再 resize 回来

3.2 第二种计算

直接给出计算代码

注意:我们的输入是(1, 3, 4, 4),如果要完成第二种方法,我们 layernorm 只需要提供一个参数,即 norm = nn.LayerNorm(3),但是如果只提供一个参数,默认为对最后一维进行归一化,所以我们需要将输入进行变化,即变为(1, 4, 4, 3)。

import torch.nn as nn
import torch


input = torch.arange(1, 49).view(3, 4, 4).type(torch.float32)
input = input.unsqueeze(0)  # [1, 3, 5, 5]

# [1, 3, 5, 5] -> [1, 5, 5, 3]
input = input.permute(0, 2, 3, 1).contiguous()
print(input)       # [1, 5, 5, 3]
'''
tensor([[[[ 1., 17., 33.],
          [ 2., 18., 34.],
          [ 3., 19., 35.],
          [ 4., 20., 36.]],

         [[ 5., 21., 37.],
          [ 6., 22., 38.],
          [ 7., 23., 39.],
          [ 8., 24., 40.]],

         [[ 9., 25., 41.],
          [10., 26., 42.],
          [11., 27., 43.],
          [12., 28., 44.]],

         [[13., 29., 45.],
          [14., 30., 46.],
          [15., 31., 47.],
          [16., 32., 48.]]]])
'''

# LayerNorm函数计算
norm = nn.LayerNorm(3)
print(norm(input))
'''
tensor([[[[-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247]],

         [[-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247]],

         [[-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247]],

         [[-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247]]]], grad_fn=<NativeLayerNormBackward0>)
'''


# 手动计算
mean = torch.mean(input, dim=3, keepdim=True)
std = torch.std(input, correction=0, dim=3, keepdim=True)
x = (input-mean)/(std+1e-5)
print(x)
'''
tensor([[[[-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247]],

         [[-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247]],

         [[-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247]],

         [[-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247],
          [-1.2247,  0.0000,  1.2247]]]])
'''

# 最后将输出resize回输入维度 [1, 5, 5, 3] -> [1, 3, 5, 5]
x = x.permute(0, 3, 1, 2)
print(x) # [1, 3, 5, 5]
'''
tensor([[[[-1.2247, -1.2247, -1.2247, -1.2247],
          [-1.2247, -1.2247, -1.2247, -1.2247],
          [-1.2247, -1.2247, -1.2247, -1.2247],
          [-1.2247, -1.2247, -1.2247, -1.2247]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 1.2247,  1.2247,  1.2247,  1.2247],
          [ 1.2247,  1.2247,  1.2247,  1.2247],
          [ 1.2247,  1.2247,  1.2247,  1.2247],
          [ 1.2247,  1.2247,  1.2247,  1.2247]]]])
'''
4月 06, 2025
3月 10, 2025
12月 31, 2024