pytorch 保存和加载模型
pytorch
本文字数:1k 字 | 阅读时长 ≈ 4 min

pytorch 保存和加载模型

pytorch
本文字数:1k 字 | 阅读时长 ≈ 4 min

1. 保存与加载模型

首先给出 PyTorch 官网的两个教程:

==这里讲一种常用的方法==

1.1 保存&&加载

torch.save(x, path)

注意这个 x 可以是一个简单的 Tensor,也可以是我们的模型参数

torch.load(path)

此函数返回和之前保存的一模一样的 x 信息,即之前保存的 x 是什么,这个函数就返回什么

这里举两个例子方便理解,一个是 Tensor 的例子,另一个是 Model 的例子

1.2 Tensor

import torch
x = torch.tensor([0, 1, 2, 3, 4])
torch.save(x, 'tensor.pth')
y = torch.load('tensor.pth')
print(y)
'''
tensor([0, 1, 2, 3, 4])
'''

此时当前文件目录下会出现 tensor.pth 文件,也就是说我们用 torch.save() 保存了变量 x,然后用 torch.load() 加载赋值给 y 输出

1.3 Model

在训练模型的时候,我们往往需要保存模型的 epoch,model 参数 以及 optimizer 的信息,保存的代码如下

torch.save({'epoch': epoch, 
            'state_dict': model_restoration.state_dict(),
            'optimizer' : optimizer.state_dict()}, 
            # os.path.join(model_dir,"model_latest.pth")
            os.path.join(model_dir,f"model_epoch_{epoch}.pth")) 

重新加载模型的程序如下

# 加载模型参数
def load_checkpoint(model, weights):
    checkpoint = torch.load(weights)
    try:
        model.load_state_dict(checkpoint["state_dict"])
    except:
        state_dict = checkpoint["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)

# 加载optimizer参数
def load_optim(optimizer, weights):
    checkpoint = torch.load(weights)
    optimizer.load_state_dict(checkpoint['optimizer'])

# 加载epoch
def load_start_epoch(weights):
    checkpoint = torch.load(weights)
    epoch = checkpoint["epoch"]
    return epoch

注意上面的 load_checkpoint 函数,如果在训练时用了 DataParallel 函数,那么最终参数会带有 module,此时就应该将其去掉

没有使用 DataParallel 的参数形式

使用 DataParallel 的参数形式,可以发现参数前带有 module

我们在保存模型时都保存了些什么呢?下面程序展示了保存的模型和优化器的一些信息,从输出可以看出,我们传入 torch.save() 中的就是模型中卷积等的 weightbias 等信息。那么为什么使用 DataParallel 之后加载参数需要去掉 module 呢,这是因为我们真实的模型中是没有 module 这个前缀的,是 conv1.weight 或者 conv1.bias,而我们使用并行计算时,参数就会被归到 module 下,就变为了 module.conv1.weight 以及 module.conv1.bias,如果在 load 的时候不把前缀 module. 去掉,模型就无法匹配参数,也就没法恢复了,所以在恢复参数的时候要注意索引是否一致

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


# define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass,self).__init__()
        self.conv1=nn.Conv2d(3, 6, 5)
        self.pool=nn.MaxPool2d(2, 2)
        self.conv2=nn.Conv2d(6, 16, 5)
        self.fc1=nn.Linear(16*5*5, 120)
        self.fc2=nn.Linear(120, 84)
        self.fc3=nn.Linear(84, 10)

    def forward(self,x):
        x=self.pool(F.relu(self.conv1(x)))
        x=self.pool(F.relu(self.conv2(x)))
        x=x.view(-1,16*5*5)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x

def main():
    # Initialize model
    model = TheModelClass()

    # Initialize optimizer
    optimizer=optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    '''
        model的state_dict()与optimizer的略有不同
        model:
            torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数
            当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中
            的state_dict也会存放batchnorm's running_mean
        optimizer:
            state_dict字典对象包含state和param_groups的字典对象,而param_groups key
            对应的value也是一个由学习率,动量等参数组成的一个字典对象
    '''
    # print model state_dict
    print('Model.state_dict: ')
    model_param = model.state_dict()
    for param_tensor in model_param:
        # print key value字典
        print(param_tensor, '\t', model.state_dict()[param_tensor].size())

    # print optimizer state_dict
    print('Optimizer state_dict: ')
    optim_param = optimizer.state_dict()
    for var_name in optim_param:
        print(var_name, '\t', optimizer.state_dict()[var_name])

if __name__=='__main__':
    main()
'''
Model.state_dict: 
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])
Optimizer state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
'''
9月 09, 2024
9月 06, 2024