pytorch 计算图和反向传播
pytorch
本文字数:1.5k 字 | 阅读时长 ≈ 6 min

pytorch 计算图和反向传播

pytorch
本文字数:1.5k 字 | 阅读时长 ≈ 6 min

参考

1. 计算图概念

Pytorch 的核心功能是自动求导机制即反向传播图/计算图,计算图在程序中并非一个实体(一个类/对象),而是由众多元素和内部运算机制组成的抽象概念.

Pytorch 的计算图由节点和边组成,节点表示张量或者 Function,边表示张量和 Function 之间的依赖关系。

Pytorch 中的计算图是动态图。这里的动态主要有两重含义。

2. pytorch 实例解释

下面先介绍 pytorch 中 tensor 的一些属性

名称 描述
tensor/tensor.data 获得该节点的值,即 Tensor 类型的值
tensor.grad 获得该节点处的梯度信息,代表$\frac{\partial loss}{\partial x}$
tensor. requires_grad True 代表此变量处需要计算梯度
tensor.grad_fn 表示变量是不是一个 Function 类的输出值。若是则 grad_fn 返回该 Function 类,否则是 None
tensor._version 表示该 tensor in-place 计算次数,默认为 0。tensor 成为 Function 类输入后,_version 要保持不变,否则 backward()时会报错
tensor.retains_gard True 代表此变量保留梯度,False 代表不保留梯度;默认无此属性,中间节点调用 tensor.retain_gard()后,为 True

举例,网络的计算图如下所示

import torch  

a = torch.tensor(2.0, requires_grad=False)  
b = torch.tensor(2.0, requires_grad=False)  
c = (a * b).requires_grad_(True)
d = torch.tensor(2.0, requires_grad=False)  
e = c * d  
f = torch.tensor(2.0, requires_grad=False)  
g = e * f

g.backward() # 只有标量scalar,才能运行.backward()执行求导

def print_tensor_info(tensor, name):
    print(f"{name:<10} Data: {tensor.data}")
    print(f"{name:<10} Grad: {tensor.grad}")
    print(f"{name:<10} Grad function: {tensor.grad_fn}")
    print(f"{name:<10} Is leaf: {tensor.is_leaf}")
    print(f"{name:<10} Requires grad: {tensor.requires_grad}")
    print("-------------------------------")

# 假设a, b, c, d, e, f, g是你的tensor变量
print_tensor_info(a, "a")
print_tensor_info(b, "b")
print_tensor_info(c, "c")
print_tensor_info(d, "d")
print_tensor_info(e, "e")
print_tensor_info(f, "f")
print_tensor_info(g, "g")

图例说明:

蓝色代表 requires_grad=False 的叶子节点
绿色代表 requires_grad=True 的叶子节点
棕色代表中间节点/非叶子节点
黄色代表反向传播节点(反向传播图)

当我们调用 backward 时,各个节点的属性如下

a          Data: 2.0
a          Grad: None
a          Grad function: None
a          Is leaf: True
a          Requires grad: False
-------------------------------
b          Data: 2.0
b          Grad: None
b          Grad function: None
b          Is leaf: True
b          Requires grad: False
-------------------------------
c          Data: 4.0
c          Grad: 4.0
c          Grad function: None
c          Is leaf: True
c          Requires grad: True
-------------------------------
d          Data: 2.0
d          Grad: None
d          Grad function: None
d          Is leaf: True
d          Requires grad: False
-------------------------------
e          Data: 8.0
e          Grad: None
e          Grad function: <MulBackward0 object at 0x7fd0c8b4f4c0>
e          Is leaf: False
e          Requires grad: True
-------------------------------
f          Data: 2.0
f          Grad: None
f          Grad function: None
f          Is leaf: True
f          Requires grad: False
-------------------------------
g          Data: 16.0
g          Grad: None
g          Grad function: <MulBackward0 object at 0x7fd0c8b4ff40>
g          Is leaf: False
g          Requires grad: True
-------------------------------

3. 中间节点的 grad

Pytorch 默认不会保存中间节点(intermediate variable)的 grad,此举是为了节省内存。详见 https://discuss.pytorch.org/t/why-cant-i-see-grad-of-an-intermediate-variable/94

实际上由上图可以知道,反向传播过程中 grad 值不会经过中间节点,而是由 Function 类到另一个 Function 类最后抵达叶子节点。

下面看一个例子解释一下这是什么意思

input = torch.tensor(log(11), requires_grad=True)  
x = 2 * input # x.grad_fn 指向 MulBackward Function  
y = 3 + x # y.grad_fn 指向 AddBackward Function  
z = torch.exp(y) # z.grad_fn 指向 ExpBackward Function 

requires_grad=Ture 的叶子节点进入 Function 类时,就会在内存中创建计算图(蓝色节点),任何计算图上的 requires_grad=True 标量节点,都可以调用 backward() 函数
x.backward() 函数调用后会释放 x.grad_fn 对应的蓝色节点以及之前的节点,而不会释放 x 之后的蓝色节点

例子:y.backward() 会释放 AddBackward FunctionMulBackward Function,而不会改变 ExpBackward Function`

推论:叶子节点进行 backward() 不会释放任何蓝色节点,所以叶子节点可以无限次 backward(),但是 backward() 时遇到已经被释放的蓝色节点,则会报错

例子:y.backward() 后运行 z.backward() 会报错

推论:如果两个 tensor 分别 backward() 时,释放的蓝色节点没有交集,则可以运行,不会报错

backward(retain_graph=True) 代表本次反向传播不释放任何蓝色节点,可以实现对同一蓝色节点的多次利用
多次 backward() 时,梯度值累加
只有标量 scalar 才能调用 backward()
中间节点的 grad: Pytorch 默认不会保存中间节点(intermediate variable)的 grad,此举是为了节省内存。详见 https://discuss.pytorch.org/t/why-cant-i-see-grad-of-an-intermediate-variable/94

实际上由上图可以知道,反向传播过程中 grad 值不会经过中间节点,而是由 Function 类到另一个 Function 类最后抵达叶子节点。

Hi Kalamaya,

By default, gradients are only retained for leaf variables. non-leaf variables’ gradients are not retained to be inspected later. This was done by design, to save memory.

However, you can inspect and extract the gradients of the intermediate variables via hooks.
You can register a function on a Variable that will be called when the backward of the variable is being processed.

More documentation on hooks is here: http://pytorch.org/docs/autograd.html#torch.autograd.Variable.register_hook 2.2k

Here’s an example of calling the print function on the variable yy to print out it’s gradient (you can also define your own function that copies the gradient over else-where or modifies the gradient, for example.

from __future__ import print_function
from torch.autograd import Variable
import torch

xx = Variable(torch.randn(1,1), requires_grad = True)
yy = 3*xx
zz = yy**2

yy.register_hook(print)
zz.backward()
Output:

Variable containing:
-3.2480
[torch.FloatTensor of size 1x1]
9月 09, 2024
9月 06, 2024