1、torch.no_grad()函数的作用
torch.no_grad()函数是禁用梯度计算的上下文管理器。当我们确信不会调用Tensor.backward()时,禁用梯度计算很有用,因为它将减少计算的内存消耗。在这种模式下,即使输入的向量的requires_grad=True,每次计算的结果也将为requires_grad=False。但是,有种例外:所有工厂函数或创建新张量的函数,都不受此模式的影响。如下代码所示。
2、torch.no_grad()函数应用场景
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
>>> @torch.no_grad
... def tripler(x):
... return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False
>>> # 工厂函数并不受no_grad的影响
>>> with torch.no_grad():
... a = torch.nn.Parameter(torch.rand(10))
>>> a.requires_grad
True
3、补充说明:什么是工厂函数?
工厂函数是用于生成tensor的函数。常见的工厂函数有torch.rand、torch.randint、torch.randn、torch.eye
等,更多介绍请移步PyTorch官网介绍:https://pytorch.org/cppdocs/notes/tensor_creation.html#factory-functions