Getting the Modules and Parameters of a Network in PyTorch
pytorch
Word count: 650 | Reading time ≈ 3 min


1. Getting the Network Structure in PyTorch

When writing deep-learning code, we usually want to inspect the network structure. The simplest way is to call print(model) directly. We will use the following program as an example:

import torch.nn as nn

class SubNet(nn.Module):
    def __init__(self):
        super(SubNet, self).__init__()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.head = SubNet()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(-1, 320)
        x = self.head(x)
        return x
    
net = Net()
print(net)

Here we instantiate the network and call print(net); the result looks like this:

Net(
  (conv1): Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (head): SubNet(
    (fc1): Linear(in_features=320, out_features=50, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
  )
)
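Because submodules are registered as attributes, any layer can also be reached directly by its attribute path, and recent PyTorch versions (1.9+) can resolve the same dotted name as a string via get_submodule. A minimal sketch, redefining the network above so the snippet is self-contained:

```python
import torch.nn as nn

# Same Net/SubNet as defined above, repeated so this snippet runs on its own
class SubNet(nn.Module):
    def __init__(self):
        super(SubNet, self).__init__()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.head = SubNet()

net = Net()

# Submodules are plain attributes, so nested layers chain naturally
fc1 = net.head.fc1
print(fc1)  # Linear(in_features=320, out_features=50, bias=True)

# get_submodule resolves the dotted path "head.fc1" to the same object
assert net.get_submodule("head.fc1") is fc1
```

This attribute-path naming is exactly the convention that named_modules and named_parameters use for their qualified names, as shown in the next sections.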

1.1 The named_modules Method

The approach above prints the network structure, but it does not give us programmatic access to each layer or its qualified name. For example, the fc1 layer here is actually named head.fc1. To iterate over these (name, module) pairs, we use the named_modules method:

named_modules()

The following example makes the method concrete. Continuing with the network above, we print every (name, module) pair; note that the first entry is the root module itself, whose name is the empty string. Taking the fc1 layer as an example, the method yields the name head.fc1 together with the module Linear(in_features=320, out_features=50, bias=True):

for name, module in net.named_modules():
    print(name)
    print(module)

"""
Net(
  (conv1): Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (head): SubNet(
    (fc1): Linear(in_features=320, out_features=50, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
  )
)
conv1
Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1))
conv2
Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
head
SubNet(
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)
head.fc1
Linear(in_features=320, out_features=50, bias=True)
head.fc2
Linear(in_features=50, out_features=10, bias=True)
"""

1.2 The named_parameters Method

named_parameters()

named_parameters works the same way but yields (name, parameter) pairs, where the name is the qualified path of each weight or bias tensor. Since printing the parameter tensors themselves would take too much space, we only print their sizes here:

for name, param in net.named_parameters():
    print(name)
    print(param.size())

"""
conv1.weight
torch.Size([10, 3, 5, 5])
conv1.bias
torch.Size([10])
conv2.weight
torch.Size([20, 10, 5, 5])
conv2.bias
torch.Size([20])
head.fc1.weight
torch.Size([50, 320])
head.fc1.bias
torch.Size([50])
head.fc2.weight
torch.Size([10, 50])
head.fc2.bias
torch.Size([10])
"""
September 09, 2024