1、torch.no_grad()函数的作用

torch.no_grad()函数是禁用梯度计算的上下文管理器。当我们确信不会调用Tensor.backward()时,禁用梯度计算很有用,因为它将减少计算的内存消耗。在这种模式下,即使输入的向量的requires_grad=True,每次计算的结果也将为requires_grad=False。但是,有种例外:所有工厂函数或创建新张量的函数,都不受此模式的影响。如下代码所示。

2、torch.no_grad()函数应用场景

>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False


>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False


>>> @torch.no_grad
... def tripler(x):
...     return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False


>>> # 工厂函数并不受no_grad的影响
>>> with torch.no_grad():
...     a = torch.nn.Parameter(torch.rand(10))
>>> a.requires_grad
True

3、补充说明:什么是工厂函数?

工厂函数是用于生成tensor的函数。常见的工厂函数有torch.rand、torch.randint、torch.randn、torch.eye等,更多介绍请移步PyTorch官网介绍:https://pytorch.org/cppdocs/notes/tensor_creation.html#factory-functions