pytorch 中 register_buffer 和 Parameter 对比
pytorch
本文字数:1k 字 | 阅读时长 ≈ 4 min

pytorch 中 register_buffer 和 Parameter 对比

pytorch
本文字数:1k 字 | 阅读时长 ≈ 4 min

本文的大部分例子来源于知乎Link

在 pytorch 模型中保存模型参数的方式如下

torch.save(model.state_dict(), path)

模型保存的是 model.state_dict()返回对象,是一个 OrderDict,他的 key 与 value 分别是模型需要保存的参数名字和值。下面介绍 parameter 和 buffer 的一些用法特点

Paramter Buffer
是否被更新
返回 model.parameters() model.buffers()
是否注册到模型中
是否随模型保存

其中 parameter 可以被 optimizer 更新,我们在优化模型参数的时候,一般都会写 SGD(model.parameters(), xxx),此外 parameter 与 buffer 均可以在保存模型参数时被保存到 OrderDict 中。下面分别介绍中这两个参数类型的区别以及如何构建

1. register_buffer

register_buffer(name, tensor, persistent=True)

buffer 的创建需要构建一个 tensor,然后将这个 tensor 注册到 buffer 中,如下

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.l1 = nn.Linear(2, 2)
        buffer = torch.randn(2, 3)  # tensor
        self.register_buffer('my_buffer', buffer)

    def forward(self, x):
        pass

model = MyModel()
for param in model.parameters():
    print(param)
for buffer in model.buffers():
    print(buffer)
print(model.state_dict())

'''
# model.parameters()
Parameter containing:
tensor([[ 0.0184, -0.3397],
        [ 0.1823, -0.2097]], requires_grad=True)
Parameter containing:
tensor([0.5309, 0.4586], requires_grad=True)

# model.buffers()
tensor([[-0.0885,  0.2578, -0.1473],
        [-0.1926,  0.2726, -0.5541]])
      
# model.state_dict()
OrderedDict([('my_buffer', tensor([[-0.0885,  0.2578, -0.1473],
        [-0.1926,  0.2726, -0.5541]])), ('l1.weight', tensor([[ 0.0184, -0.3397],
        [ 0.1823, -0.2097]])), ('l1.bias', tensor([0.5309, 0.4586]))])
'''

模型中一共有两种类型的参数

  1. 一个是 linear 操作,其中 linear 的 weight 和 bias 会随着 model.parameters 输出,并且参数可以被 optimizer 优化
  2. 一个是 buffer,buffer 类型的参数会随着 model.buffers()输出,不能被 optimizer 优化
  3. 在模型保存时,model.state_dict 可以将两种参数都保存

2. Parameter

nn.Parameter(data=None, requires_grad=True)

parameter 类型的变量也会自动注册到模型中,具有梯度,可以被 optimizer 进行优化。可以将其理解为和 linear 等参数相同的参数类型,如下

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.l1 = nn.Linear(2, 2)
        self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量

    def forward(self, x):
        # 可以通过 self.param 和 self.my_buffer 访问
        pass

model = MyModel()
for param in model.parameters():
    print(param)
print("----------------")
print(model.state_dict())

'''
# model.parameters()
Parameter containing:
tensor([[-0.4412, -2.0199,  0.7088],
        [ 0.6840,  1.0006,  0.1266],
        [ 0.9492, -0.0404, -0.6280]], requires_grad=True)
Parameter containing:
tensor([[ 0.6309, -0.1017],
        [ 0.1819,  0.3834]], requires_grad=True)
Parameter containing:
tensor([0.5768, 0.1148], requires_grad=True)

# model.state_dict()
OrderedDict([('param', tensor([[-0.4412, -2.0199,  0.7088],
        [ 0.6840,  1.0006,  0.1266],
        [ 0.9492, -0.0404, -0.6280]])), ('l1.weight', tensor([[ 0.6309, -0.1017],
        [ 0.1819,  0.3834]])), ('l1.bias', tensor([0.5768, 0.1148]))])
'''

这里模型只有一种类型的参数了,即 model.parameter,没有 model.buffer 类型

  1. 一个是 linear 操作,其中 linear 的 weight 和 bias 会随着 model.parameters()输出,并且参数可以被 optimizer 优化
  2. 一个是 parameter 类型,参数会随着 model.parameters()输出,参数可以被 optimizer 优化
  3. 在模型保存时,model.state_dict 会保存上述参数

3. 一些疑问

  1. 为什么不将参数都设为 nn.Parameter,只是把不需要修改的参数设置为 requires_grad=False?

如果不想将参数进行 optimizer 的更新,设置为 buffer 类型的话会给人更直观的感觉,表达更清晰。当然如果设置为 nn.Parameter 并且 grad 设为 false 也可以

  1. 为什么不直接将不需要进行更改的参数变量设为普通 tensor 变量?

下面通过一个例子来说明,为什么必须注册为 parameters 或者 buffer

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_tensor = torch.randn(1)  # 参数直接作为模型类成员变量
        self.register_buffer('my_buffer', torch.randn(1))  # 参数注册为 buffer
        self.my_param = nn.Parameter(torch.randn(1))

    def forward(self, x):
            return x

model = MyModel()
print(model.state_dict())
model.cuda()
print(model.my_tensor)
print(model.my_buffer)

'''
OrderedDict([('my_param', tensor([-1.0101])), ('my_buffer', tensor([-0.5266]))])
tensor([-0.2454])
tensor([-0.5266], device='cuda:0')
'''

如上,如果在模型中仅设置一个普通 tensor 的话,他并不会成为模型的一部分,也不会随模型移动到 cuda 中

9月 09, 2024
9月 06, 2024