nn.BatchNorm
pytorch
本文字数:2.7k 字 | 阅读时长 ≈ 12 min

nn.BatchNorm

pytorch
本文字数:2.7k 字 | 阅读时长 ≈ 12 min

本篇博客主要讲解 BatchNorm 函数的执行过程,需要读者有一定的批归一化的基础,本文例子通俗易懂,如果没有基础也可以阅读

PyTorchBatchNorm 有三个函数,这里主要讲解前两个,后面的就很容易理解,首先要明白批归一化是做什么的:BatchNorm 在深度网络中用来加速神经网络的训练,能够加速收敛并且可以使用较大的学习率,同时归一化还有一定的正则作用

torch.nn.BatchNorm1d

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

归一化的结果如下所示:
y=xE[x]Var[x]+ϵγ+β
其中E[x]为样本某一维的均值,Var[x]为样本某一维的方差,注意这里计算方差的时候使用的是==有偏估计==,对应的函数为torch.var(input, unbiased=False),什么意思呢,就是说计算方差的时候分母为N而不是N1,即Var[x]=1NNi=1(xi¯x)

下面通过一个例子来详细介绍各个参数的作用

初始化

下面我们将nn.BatchNorm1d(2, affine=False)中的affine有的设为了True,有的设为了False,这里默认为True,就是说对于上面公式中的γβ我们是要学习的,在BatchNorm中他们的参数分别为batch.weight以及batch.bias,默认为1和0,之后通过学习反向传播时会发生变化

import torch
import numpy as np
import torch.nn as nn


torch.manual_seed(1)
m = nn.BatchNorm1d(2)  # With Learnable Parameters
print('m:', m)
n = nn.BatchNorm1d(2, affine=False)  # Without Learnable Parameters
print('n:', n)
input = torch.randn(3, 2)
print('input:', input)
'''
m: BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
n: BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
input: tensor([[-0.2293, -1.4997],
        [-0.7896, -1.0517],
        [-1.0352,  0.9538]])
'''

求上述数据的均值和方差

我们初始化的input维度为(3, 2),也就是说我们的batch取了3,每一个数据有两个特征,下面我们用torchnumpy来计算input的均值和方差,从中我们可以看出numpy是用==有偏估计==来计算方差的

最后我们手动计算均值和方差来进行一下验证,可以看到均值的计算是正确的,方差仅计算了第一列(也就是第一个特征)也是正确的,最后输出的y就是归一化的输出值[1.25506570.063079371.1919862]T这里加上转置,因为是第一列归一化的结果,所以是列排列,这个y在这一节不重要,我们在下面会再次提到

# tensor求均值和方差
torch_mean = torch.mean(input, dim=0)
torch_var_biased = torch.var(input, dim=0, unbiased=False)
torch_var = torch.var(input, dim=0)
print('torch_mean:------------', torch_mean)
print('torch_var:-------------', torch_var)
print('torch_var_biased:------', torch_var_biased)
print()
'''
torch_mean:------------ tensor([0.0904, 0.2407])
torch_var:------------- tensor([0.3105, 0.1555])
torch_var_biased:------ tensor([0.2070, 0.1037])
'''


# numpy求均值和方差
input_numpy = input.numpy()
numpy_mean = np.mean(input_numpy, axis=0)
numpy_var = np.var(input_numpy, axis=0)
print('numpy_mean:------------', numpy_mean)
print('numpy_var:-------------', numpy_var)
print()
'''
numpy_mean:------------ [0.09037449 0.24070372]
numpy_var:------------- [0.20696904 0.10368937]
'''


# manual
man_mean = np.sum(input_numpy, axis=0)/3
print('man_mean:--------------', man_mean)
temp = input_numpy[:,0]-man_mean[0]
print('temp:------------------', temp)
man_var = np.sum(np.power(temp, 2))/3
print(man_var)
y = (input_numpy[:,0]-man_mean[0])/(np.sqrt(man_var))
print('y:---------------------', y)
print()
'''
man_mean:-------------- [0.09037449 0.24070372]
temp:------------------ [ 0.5709777  -0.02869723 -0.54228044]
0.20696904261906943
y:--------------------- [ 1.2550657  -0.06307937 -1.1919862 ]
'''

归一化

下面验证归一化的结果,发现torch.nn.BatchNorm1d的输出和我们手动计算的==有偏估计==方差的结果是一样的

# Batchnorm
output = m(input) # 列归一化
print('output:', output)
'''
output: tensor([[ 1.2550,  0.0814],
        [-0.0631,  1.1819],
        [-1.1920, -1.2634]], grad_fn=<NativeBatchNormBackward>)
'''


# manual batchnorm
output_val = (input-torch_mean)/torch.sqrt(torch_var_biased)
print('output_val:', output_val)
'''
output_val: tensor([[ 1.2551,  0.0814],
        [-0.0631,  1.1820],
        [-1.1920, -1.2634]])
'''

BatchNorm源码

在第五节中我们会讲到running_mean以及running_var,所以这里先看一下BatchNorm手动定义的源码(非官方),首先我们思考这样一个问题,在训练的时候我们有一批样本batch来求均值和方差,如果是测试的时候呢?我们每次只输入一个数据,如何计算均值和方差呢?这里BatchNorm函数已经替我们考虑好了,这就涉及到了running_mean以及running_var这两个变量

训练时的BatchNorm,一共分为三步

def Batchnorm_for_train(x, gamma, beta, bn_param):
"""
param:x    : 输入数据,设shape(B,L)
param:gama : 缩放因子  γ
param:beta : 平移因子  β
param:bn_param   : batchnorm所需要的一些参数
    eps      : 接近0的数,防止分母出现0
    momentum : 动量参数,一般为0.1, 0.01, 0.001
    running_mean :滑动平均的方式计算新的均值,训练时计算,为测试数据做准备
    running_var  : 滑动平均的方式计算新的方差,训练时计算,为测试数据做准备
"""
    running_mean = bn_param['running_mean']  #shape = [B]
    running_var = bn_param['running_var']    #shape = [B]
    results = 0. # 建立一个新的变量
 
    x_mean=x.mean(axis=0)  # 计算x的均值
    x_var=x.var(axis=0)    # 计算方差
    x_normalized=(x-x_mean)/np.sqrt(x_var+eps)       # 归一化
    results = gamma * x_normalized + beta            # 缩放平移
 
    running_mean = (1-momentum) * running_mean + momentum * x_mean
    running_var = (1-momentum) * running_var + momentum * x_var
 
    #记录新的值
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var 
 
    return results , bn_param

测试时的BatchNorm

def Batchnorm_for_test(x, gamma, beta, bn_param):
"""
param:x    : 输入数据,设shape(B,L)
param:gama : 缩放因子  γ
param:beta : 平移因子  β
param:bn_param   : batchnorm所需要的一些参数
    eps      : 接近0的数,防止分母出现0
    momentum : 动量参数,一般为0.9, 0.99, 0.999
    running_mean :滑动平均的方式计算新的均值,训练时计算,为测试数据做准备
    running_var  : 滑动平均的方式计算新的方差,训练时计算,为测试数据做准备
"""
    running_mean = bn_param['running_mean']  #shape = [B]
    running_var = bn_param['running_var']    #shape = [B]
    results = 0. # 建立一个新的变量
 
    x_normalized=(x-running_mean )/np.sqrt(running_var +eps)       # 归一化
    results = gamma * x_normalized + beta            # 缩放平移
 
    return results , bn_param

注意:在用PyTorch来训练和测试数据时,在训练时加入model.train(),测试时加入model.eval()来让BatchNorm区分到底是训练还是测试

实例

接着上述1、2、3的内容,我们通过一个模型来查看BatchNorm的参数变化,如下所示,我们定义了一个模型,里面只有一个BatchNorm1d,我们将之前随机的样本继续作为输入,可以看到当设置affine默认为True的时候,输出是有batch.weight tensor([1., 1.])以及batch.bias tensor([0., 0.])。并且输出结果与3的结果相同,这里特别注意下面两个参数

这时我们的第一次计算,这两个参数是怎么得到的呢,还记得初始化的时候track_running_stats为True吗,这两个参数就是在他为True的时候通过momentum来计算的,如果track_running_stats为False就没有running_meanrunning_var两个变量了,公式如下

在上面我们已经计算出了均值和方差分别为:

torch_mean:------------- tensor([0.0904, 0.2407])
torch_var:----------------- tensor([0.3105, 0.1555])
torch_var_biased:------ tensor([0.2070, 0.1037])

所以

$running_mean = [0.90+0.10.0903, ~~ 0.90+0.10.2407] = [0.0090, 0.0241]$

$running_var = [0.91+0.10.2070, ~~ 0.91+0.10.1037] = [0.9207, 0.9104]$

等等!var的结果怎么和tensor([0.9310,0.9156])不一样了!还记得之前说过的无偏估计和有偏估计吗,在计算批量数据的方差时我们采用的是有偏估计,但是当用来计算running_meanrunning_var时,就要采用无偏估计了,我们计算看一下

$running_var = [0.91+0.10.3105, ~~ 0.91+0.10.1555] = [0.9311, 0.9155]$,这下就和上面的一样了

注意:由于在计算running_meanrunning_var的时候用的是无偏估计,所以分母为N1,这就要求我们的 batch_size 必须大于一,不然会出错

class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.batch = nn.BatchNorm1d(2, momentum=0.1)

    def forward(self, input):
        output = self.batch(input)
        return output
    
model = Model()
print(model)
model.train()

print("1.-----------------------------------------------------------------")
out = model(input)
print("model out: ", out)
model_param = model.state_dict()
for param_tensor in model_param:
    # print key value字典
    print(param_tensor, '\t', model.state_dict()[param_tensor], '\t', model.state_dict()[param_tensor].size())
'''
1.-----------------------------------------------------------------
model out:  tensor([[ 1.2550,  0.0814],
        [-0.0631,  1.1819],
        [-1.1920, -1.2634]], grad_fn=<NativeBatchNormBackward>)
batch.weight     tensor([1., 1.])        torch.Size([2])
batch.bias       tensor([0., 0.])        torch.Size([2])
batch.running_mean       tensor([0.0090, 0.0241])        torch.Size([2])
batch.running_var        tensor([0.9310, 0.9156])        torch.Size([2])
batch.num_batches_tracked        tensor(1)       torch.Size([])
'''

下面我们再次运行上述程序的结果,由于我们的输入没有变化,在第一步已经进行归一化了,所以这一步的输出不变,但是由于第一步计算过了 running_meanrunning_var,所以这里就不是 0 和 1 了,带入以后和下面的输出是一致的,这里就不进行验证了

print("2.-----------------------------------------------------------------")
out = model(input)
print("model out: ", out)
model_param = model.state_dict()
for param_tensor in model_param:
    # print key value字典
    print(param_tensor, '\t', model.state_dict()[param_tensor], '\t', model.state_dict()[param_tensor].size())
'''
2.-----------------------------------------------------------------
model out:  tensor([[ 1.2550,  0.0814],
        [-0.0631,  1.1819],
        [-1.1920, -1.2634]], grad_fn=<NativeBatchNormBackward>)
batch.weight     tensor([1., 1.])        torch.Size([2])
batch.bias       tensor([0., 0.])        torch.Size([2])
batch.running_mean       tensor([0.0172, 0.0457])        torch.Size([2])
batch.running_var        tensor([0.8690, 0.8396])        torch.Size([2])
batch.num_batches_tracked        tensor(2)       torch.Size([])
'''

在第三步和第四步我们改变输入,model 的输出也会随之改变,有兴趣的同学可以手动算一下

print("3.-----------------------------------------------------------------")
input = torch.randn(3, 2)
out = model(input)
print("model out: ", out)
model_param = model.state_dict()
for param_tensor in model_param:
    # print key value字典
    print(param_tensor, '\t', model.state_dict()[param_tensor], '\t', model.state_dict()[param_tensor].size())
'''
3.-----------------------------------------------------------------
model out:  tensor([[-1.3844,  1.1957],
        [ 0.4426, -1.2518],
        [ 0.9419,  0.0560]], grad_fn=<NativeBatchNormBackward>)
batch.weight     tensor([1., 1.])        torch.Size([2])
batch.bias       tensor([0., 0.])        torch.Size([2])
batch.running_mean       tensor([-0.0993,  0.0332])      torch.Size([2])
batch.running_var        tensor([0.7931, 0.7779])        torch.Size([2])
batch.num_batches_tracked        tensor(3)       torch.Size([])
'''

第四步同第三步,不做过多介绍

print("4.-----------------------------------------------------------------")
input = torch.randn(3, 2)
out = model(input)
print("model out: ", out)
model_param = model.state_dict()
for param_tensor in model_param:
    # print key value字典
    print(param_tensor, '\t', model.state_dict()[param_tensor], '\t', model.state_dict()[param_tensor].size())
Model(
  (batch): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
'''
4.-----------------------------------------------------------------
model out:  tensor([[-0.3666, -1.1052],
        [ 1.3661,  1.3167],
        [-0.9995, -0.2115]], grad_fn=<NativeBatchNormBackward>)
batch.weight     tensor([1., 1.])        torch.Size([2])
batch.bias       tensor([0., 0.])        torch.Size([2])
batch.running_mean       tensor([-0.0958, -0.0104])      torch.Size([2])
batch.running_var        tensor([0.7329, 0.7390])        torch.Size([2])
batch.num_batches_tracked        tensor(4)       torch.Size([])
'''
5月 06, 2025
4月 27, 2025
ufw
4月 06, 2025
ufw