pytorch 检查点 checkpoint
pytorch
本文字数:2.4k 字 | 阅读时长 ≈ 10 min

pytorch 检查点 checkpoint

pytorch
本文字数:2.4k 字 | 阅读时长 ≈ 10 min

torch.utils.checkpoint 官方文档

1. Checkpoint 检查点

PyTorch 中的 checkpoint 机制用于高效地管理内存。它在前向传播过程中不保留中间激活值,从而节省内存。与常规方法不同,它不保存整个计算图的所有中间激活值,而是在反向传播过程中重新计算这些值。这意味着在训练过程中,checkpoint 操作可以减少内存占用,但需要在反向传播时重新计算的时间成本。这种操作可以应用于网络的任意部分

在前向传播过程中,checkpoint 操作将以 torch.no_grad 模式运行,即不保存中间激活值。相反,它会保存输入元组和 function 参数。在反向传播过程中,先前保存的输入元组和 function 参数将被重新提取,然后重新计算前向传播的中间激活值。接着,根据这些重新计算的中间激活值计算梯度

在 PyTorch 进行深度学习训练时,显存开销主要包括四个部分:模型参数(parameters)、模型参数的梯度(gradients)、优化器状态(optimizer states)和中间激活值(intermediate activations)。通过使用 checkpoint 技术,我们可以利用 PyTorch 提供的 torch.no_grad 模式避免将这部分运算记录到反向图中,从而减少对中间激活值的存储需求

需要注意的是,在前向传播过程中,autograd 会记录反向传播所需的一些信息和中间变量。反向传播完成后,用于计算梯度的中间结果将被释放。这意味着模型参数、优化器状态和参数梯度始终占用存储空间,而中间激活值在反向传播完成后会自动被清空

2. torch.utils.checkpoint

Checkpoint 机制的步骤可以总结如下:

  1. 添加中间层: 在原始模型和整体计算图之间添加一个中间层,用于信息交互
  2. 记录目标层: 将原始模型的数据传输到被包裹的目标层时,数据进入 checkpoint 的 forward() 中,进行检查和记录,然后再送入目标层
  3. 目标层前向传播: 目标层在非梯度模式下执行前向传播,新创建的 tensor 不会记录梯度信息
  4. 前向传播: 目标层的结果通过 checkpoint 的前向传播输出,送入模型后续的其他结构
  5. 反向传播: 执行反向传播,计算损失求导、链式回传和梯度
  6. 目标层反向传播: 回传的梯度被送入 checkpoint 的 backward()函数
  7. 目标层第二次前向传播: 为了获取目标层的梯度信息,需要在梯度状态下对目标层进行一次前向传播。通过执行 torch.autograd.backward(outputs_with_grad, args_with_grad),将回传的梯度和目标层的输出一起计算,从而获得对应输入的梯度信息
  8. 计算目标层反向梯度: 将目标操作输入的梯度信息按照 checkpoint 本身 Function 的 backward 需求,使用 None 对其他辅助参数的梯度占位后进行返回
  9. 完成反向传播: 返回的梯度将沿着反向传播路径送入对应操作的 backward 中,逐层回传累加到各个叶子节点上

3. 说明

  1. 在反向传播过程中,checkpoint 通过重新运行每个检查段的前向传播计算来实现。这可能导致连续状态(如 RNG 状态)比没有检查点的状态更高级
  2. 默认情况下,checkpoint 包含处理 RNG 状态的逻辑,这样通过使用 RNG(如 dropout)进行的检查点传递与非检查点传递相比具有确定的输出。但是,存储和还原 RNG 状态的逻辑可能导致性能下降
  3. 如果不需要与非检查点传递相比确定的输出,可以设置 preserve_rng_state=False,来忽略在每个检查点期间隐藏和恢复 RNG 状态。这可以确保在 checkpoint 中保存 dropout 这样的 RNG 状态逻辑,使得前后两次运行的结果一致
  4. 隐藏逻辑将当前设备以及所有 CUDA 张量参数的设备备的 RNG 状态保存并恢复到 run_fn。但是,该逻辑无法预料用户是否会在 run_fn 本身内将张量移动到新设备里。因此,如果在 run_fn 内将张量移动到新的设备里(新设备指不属于集合[当前设备+张量参数的设备]的设备备),则与非检查点传递相比确定的输出将不再确保是确定的
  5. 在使用 checkpoint 时,应避免在 run_fn 内随意修改张量的 device,否则 preserve_rng_state 参数可能失效,导致结果无法事先确定

通过了解这些注意事项,可以在使用 checkpoint 时避免一些潜在的问题,确保模型训练过程中内存管理和梯度计算的正确性。

4. 使用示例

注意:第一层建议不要使用 checkpoint,dropout 和 batch normalization 层不能用 checkpoint(二者起冲突)。

checkpoint 函数咋新版本的 pytorch 中有 use_reentrant 参数,如果设置为 False,那么在前向传播过程中不会记录 autograd 图,而是在 checkpoint 函数中记录,这样可以减少内存占用,但是在反向传播时会重新计算前向传播的中间激活值,这样会增加计算时间。如果设置为 True,那么在前向传播过程中会记录 autograd 图,这样可以减少计算时间,但是会增加内存占用,一般推荐设置为 use_reentrant=False

下面通过一个例子直接展示一下 checkpoint 的使用方法

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.checkpoint import checkpoint

# 定义一个具有残差连接的卷积层
class ResidualConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualConv, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out += identity
        out = self.relu(out)
        return out

# 创建一个具有残差连接的网络
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.res1 = ResidualConv(64, 64)
        self.res2 = ResidualConv(64, 64)
        self.res3 = ResidualConv(64, 64)

    def forward(self, x):
        x = self.conv1(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        return x

# 使用检查点的网络
class CheckpointResNet(nn.Module):
    def __init__(self):
        super(CheckpointResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.res1 = ResidualConv(64, 64)
        self.res2 = ResidualConv(64, 64)
        self.res3 = ResidualConv(64, 64)

    def forward(self, x):
        x = self.conv1(x)
        x = checkpoint(self.res1, x, use_reentrant=False)
        x = checkpoint(self.res2, x, use_reentrant=False)
        x = checkpoint(self.res3, x, use_reentrant=False)
        return x

# 创建输入张量
input = Variable(torch.randn(1, 3, 224, 224), requires_grad=True)

# 创建网络
resnet = ResNet()
checkpoint_resnet = CheckpointResNet()

# 比较内存消耗
torch.cuda.reset_peak_memory_stats()
resnet.cuda()
output = resnet(input.cuda())
output.backward(torch.ones_like(output))
print("Memory usage without checkpoint:", torch.cuda.max_memory_allocated()/1024/1024, "MB")

torch.cuda.reset_peak_memory_stats()
checkpoint_resnet.cuda()
output = checkpoint_resnet(input.cuda())
output.backward(torch.ones_like(output))
print("Memory usage with checkpoint:", torch.cuda.max_memory_allocated()/1024/1024, "MB")

输出如下

use_reentrant=False
Memory usage without checkpoint: 180.2099609375 MB
Memory usage with checkpoint: 157.41796875 MB

use_reentrant=True
Memory usage without checkpoint: 180.2099609375 MB
Memory usage with checkpoint: 169.9501953125 MB

5. 注意事项

当网络较小时,使用 checkpoint 可能会导致显存占用增加。这主要是因为 checkpoint 的工作原理是将网络分成多个段,每个段的激活值都会在计算过程中存储。这有助于减少网络中的激活值数量,从而减少显存占用。然而,这种方法的效果在较小的网络中可能并不明显,因为小网络本身的激活值数量较少。

此外,checkpoint 还会在计算图中引入额外的结构,这可能会导致显存占用增加。这在小网络中可能更为明显,因为额外的计算图结构相对于整个网络的显存占用更为显著。

总之,当网络较小时,使用 checkpoint 可能无法显著降低显存占用,甚至可能导致显存占用增加。在这种情况下,你可以考虑不使用 checkpoint,或者尝试其他方法来减少显存占用,例如使用更小的批量大小(batch size)或使用混合精度训练。

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.checkpoint import checkpoint

# 创建一个简单的网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x

# 使用检查点的网络
class CheckpointNet(nn.Module):
    def __init__(self):
        super(CheckpointNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

    def forward(self, x):
        x = checkpoint(self.conv1, x, use_reentrant=False)
        x = checkpoint(self.conv2, x, use_reentrant=False)
        x = checkpoint(self.conv3, x, use_reentrant=False)
        x = checkpoint(self.conv4, x, use_reentrant=False)
        return x

# 创建输入张量
input = Variable(torch.randn(1, 3, 224, 224), requires_grad=True)

# 创建网络
simple_net = SimpleNet()
checkpoint_net = CheckpointNet()

# 比较内存消耗
torch.cuda.reset_peak_memory_stats()
simple_net.cuda()
output = simple_net(input.cuda())
output.backward(torch.ones_like(output))
print("Memory usage without checkpoint:", torch.cuda.max_memory_allocated()/1024/1024, "MB")

torch.cuda.reset_peak_memory_stats()
checkpoint_net.cuda()
output = checkpoint_net(input.cuda())
output.backward(torch.ones_like(output))
print("Memory usage with checkpoint:", torch.cuda.max_memory_allocated()/1024/1024, "MB")

输出结果如下,可以看到在网络很小的时候,反而没有节省显存

use_reentrant=False
Memory usage without checkpoint: 130.78662109375 MB
Memory usage with checkpoint: 131.64794921875 MB

use_reentrant=True
Memory usage without checkpoint: 130.78662109375 MB
Memory usage with checkpoint: 143.89794921875 MB
9月 09, 2024
9月 06, 2024