1、什么是叶子节点

PyTorch的最大特点是动态计算图,计算图是用来描述运算的有向无环图。计算图有两种主要元素:结点(Node)和边(Edge)。结点表示数据,例如张量,而边表示运算,例如加、减、乘、除、卷积等。

对于节点而言,又分为叶子节点和非叶子节点。我们通常关注的叶子节点,那什么叶子节点呢?PyTorch中的张量tensor有一个属性是is_leaf,当is_leaf为True时,该tensor是叶子张量,也叫叶子节点。

2、叶子节点的作用

PyTorch具有自动求导的功能,当requires_grad=True时,PyTorch会自动记录运算过程,缓存运算中的中间参数,为自动求导做准备。但是,只有is_leaf=True和requires_grad=True同时满足时,我们才可以获得该节点的导数值。否则,对于非叶子节点而言,PyTorch出于节省内存的考虑,通常不会保存节点的到数值。

3、注意事项

判断叶子节点有两个标准:

(1)如果是用户自己创建的张量,requires_grad无论是True还是False,都是叶子节点。

需要说明一下:默认情况下,我们创建的张量tensor的requires_grad都是False值的,因为我们训练网络训练的是网络模型的权重,而不需要训练输入。我们创建的张量都是叶子张量,如下tensor的定义所示:

torch.tensor(data, dtype=None, device=None,requires_grad=False)

(2)如果不是用户自己创建的张量,例如张量b,b = a + 1,当设置requires_grad为False时候,它则变为叶子节点。