1、PyTorch自动求导代码

import torch

input = torch.ones([2, 2], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)
w3 = torch.tensor(4.0, requires_grad=True)

l1 = input * w1
l2 = l1 + w2
l3 = l1 * w3
l4 = l2 * l3
loss = l4.mean()

loss.backward()

print(w1.grad, w2.grad, w3.grad)

结果为:

tensor(28.) tensor(8.) tensor(10.)

2、PyTorch自动求导的计算图

注意:图中的圆形表示操作符,因为执行操作符之后必定会出现一个中间结果,所以圆形代表操作符和操作数(中间结果)。中间结果非用户生成,属于非叶子节点。而方框表示的操作数为叶子节点。

3、链式求导计算公式

4、参考

部分内容参考自互联网。