pytorch 中 Sequential 和 ModuleList
pytorch
本文字数:774 字 | 阅读时长 ≈ 3 min

pytorch 中 Sequential 和 ModuleList

pytorch
本文字数:774 字 | 阅读时长 ≈ 3 min

在介绍 nn.Sequential 和 nn.ModuleDict 之前,我们需要知道在 pytorch 构建的 model 核心是 nn.Module 模块,下面举个例子

class model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(3, 20, 5)

    def forward(self, x):
        x = F.relu(conv(x))
        return x

在了解这个基本概念之后,我们分别介绍这两个模块

nn.Sequential

nn.Sequential 继承自 nn.Module 模块,因此他自带 forward 函数,下面我们看一个例子

model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
print(model)
'''
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
'''

# 给每一步的模块进行命名
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))
print(model)
'''
Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
)
'''

input = torch.randn([1, 1, 10, 10])
output = model(input)
print(output.size()) # torch.Size([1, 64, 2, 2])

如上所示,我们可以得到一些结论

  1. 在 nn.Sequential 里面的每一个操作是逐步执行的,不可改变顺序,如果第一步的输出与第二步的输入不匹配就会报错
  2. 可以通过 OrderedDict 来改变 nn.Sequential 里面每一步的名字。注意,即使改变了名字,索引时也需要用 0,1,2…,例如 model[0]=Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1),model[‘conv1’]会报错

nn.ModuleList

nn.ModuleDict 没有继承自 nn.Module,所以不能像 nn.Sequential 那样有 forward 功能。可以将其看做一个列表的形式,能够将多个操作存放在一个列表里

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i](x)
        return x

model = MyModule()
input = torch.randn([1, 10])
output = model(input)
print(output.size())  # torch.Size([1, 10])

如上所示,这里总结 nn.ModelList 的一些特点

  1. nn.ModelList 是单纯的列表形式,当我们想快速构建一些操作(例如例子中的 linear 操作时,可以使用 modellist)
  2. nn.ModelList 不具备 forward 功能,所以我们调用里面的操作时,需要进行索引,然后才能运行这个操作
  3. nn.ModelList 列表内的操作可以是乱序的,比如我先用 list[3],再用 list[0],而 nn.Sequential 的执行顺序不能打乱

为什么不能用 python 中的 list 来代替 nn.ModelList 呢?

因为 nn.ModelList 可以将里面的列表操作自动注册到整个网络中,但是如果是 python 的 list,则会出问题,如下

class net_modlist(nn.Module):
    def __init__(self):
        super(net_modlist, self).__init__()
        self.modlist = nn.ModuleList([
                    nn.Conv2d(1, 20, 5),
                    nn.Conv2d(20, 64, 5),])

    def forward(self, x):
        for m in self.modlist:
            x = m(x)
        return x

model = net_modlist()
for param in model.parameters():
    print(type(param.data), param.size())

'''
nn.ModuleList
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([64, 20, 5, 5])
<class 'torch.Tensor'> torch.Size([64])

将nn.ModuleList换为单纯的list
None # 输出为None,表示conv操作并没有加入到模型参数中
'''
9月 09, 2024
9月 06, 2024