1. 梯度累加
梯度累加(Gradient Accmulation)是一种增大训练时 batch size 的技巧。当 batch size 在一张卡放不下时,可以将很大的 batch size 分解为一个个小的 mini batch,分别计算每一个 mini batch 的梯度,然后将其累加起来优化
正常的 pytorch 训练流程如下(来自知乎)
for i, (image, label) in enumerate(train_loader):
pred = model(image) # 1
loss = criterion(pred, label) # 2
optimizer.zero_grad() # 3
loss.backward() # 4
optimizer.step() # 5
- 神经网络 forward 过程
- 获取 loss,通过 pred 和 label 计算你损失函数
- 清空网络中参数的梯度
- 反向传播,计算当前梯度
- 根据梯度更新网络参数
使用梯度累加的方法如下
for i,(image, label) in enumerate(train_loader):
# 1. input output
pred = model(image)
loss = criterion(pred, label)
# 2.1 loss regularization
loss = loss / accumulation_steps
# 2.2 back propagation
loss.backward()
# 3. update parameters of net
if (i+1) % accumulation_steps == 0:
# optimizer the net
optimizer.step() # update parameters of net
optimizer.zero_grad() # reset gradient
- 神经网络 forward 过程,同时计算损失函数
- 反向传播计算当前梯度(在 backward 时,计算的 loss 要除 batch 的大小得到均值)
- 不断重复 1、2 步骤,重复获取梯度
- 梯度累加到一定次数后,先 optimizer.step()更新网络参数,随后 zero_grad()清除梯度,为下一次梯度累加做准备
2. DDP 中的梯度累加
问题:在 DDP 中所有卡的梯度 all_reduce 阶段发生在 loss.bachward()阶段,也就是说执行 loss.backward()之后,所有卡的梯度会进行一次汇总,但是如果我们如果使用梯度累加策略,假设梯度累加 K=2,就需要 all_reduce 汇总两次,会带来额外的计算错误和时间开销
解决方案:知乎写的很好,这里参考其解决方案,只需要在前 K-1 次取消梯度同步即可,DDP 提供了一个暂时取消梯度同步的 context 函数 no_sync(),在这个函数下,DDP 不会进行梯度同步
model = DDP(model)
for 每次梯度累加循环
optimizer.zero_grad()
# 前accumulation_step-1个step,不进行梯度同步,每张卡分别累积梯度。
for _ in range(K-1)::
with model.no_sync():
prediction = model(data)
loss = loss_fn(prediction, label) / K
loss.backward() # 积累梯度,但是多卡之间不进行同步
# 第K个step
prediction = model(data)
loss = loss_fn(prediction, label) / K
loss.backward() # 进行多卡之间的梯度同步
optimizer.step()
优雅写法
from contextlib import nullcontext
# 如果你的python版本小于3.7,请注释掉上面一行,使用下面这个:
# from contextlib import suppress as nullcontext
if local_rank != -1:
model = DDP(model)
optimizer.zero_grad()
for i, (data, label) in enumerate(dataloader):
# 只在DDP模式下,轮数不是K整数倍的时候使用no_sync
my_context = model.no_sync if local_rank != -1 and i % K != 0 else nullcontext
with my_context():
prediction = model(data)
loss = loss_fn(prediction, label) / K
loss.backward() # 积累梯度,不应用梯度改变
if i % K == 0:
optimizer.step()
optimizer.zero_grad()
3. 梯度累加的影响
本文由 Yonghui Wang 创作,采用
知识共享署名4.0
国际许可协议进行许可
本站文章除注明转载/出处外,均为本站原创或翻译,转载前请务必署名
最后编辑时间为:
Dec 19, 2024 12:13 pm