torch.norm 是 PyTorch 中的一个函数,用于计算张量的范数(norm)。范数是对向量或矩阵的“大小”或“长度”的一种度量。torch.norm 可以计算不同类型的范数,如 L1 范数、L2 范数、Frobenius 范数等。

函数签名

torch.norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)

参数说明

  • input (Tensor): 输入张量。
  • p (int, float, inf, -inf, 'fro', 'nuc', optional): 范数的类型。默认是 'fro'(Frobenius 范数)。

    • p=1: L1 范数(绝对值和)。
    • p=2: L2 范数(欧几里得范数)。
    • p='fro': Frobenius 范数(矩阵的 L2 范数)。
    • p='nuc': 核范数(矩阵的奇异值之和)。
    • p=inf: 无穷范数(最大绝对值)。
    • p=-inf: 负无穷范数(最小绝对值)。
  • dim (int, tuple of ints, optional): 指定计算范数的维度。如果为 None,则计算整个张量的范数。
  • keepdim (bool, optional): 是否保持输出的维度。如果为 True,则输出的维度与输入相同,除了指定的维度被缩减为 1。
  • out (Tensor, optional): 输出张量。
  • dtype (torch.dtype, optional): 返回张量的数据类型。

使用举例

1. 计算整个张量的 L2 范数

import torch

x = torch.tensor([1.0, 2.0, 3.0])
norm = torch.norm(x, p=2)
print(norm)  # 输出: tensor(3.7417)

2. 计算矩阵的 Frobenius 范数

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
fro_norm = torch.norm(A, p='fro')
print(fro_norm)  # 输出: tensor(5.4772)

3. 计算矩阵的 L1 范数(按列计算)

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
l1_norm = torch.norm(A, p=1, dim=0)
print(l1_norm)  # 输出: tensor([4., 6.])

4. 计算矩阵的 L2 范数(按行计算)

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
l2_norm = torch.norm(A, p=2, dim=1)
print(l2_norm)  # 输出: tensor([2.2361, 5.0000])

5. 计算矩阵的无穷范数(按行计算)

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
inf_norm = torch.norm(A, p=float('inf'), dim=1)
print(inf_norm)  # 输出: tensor([2., 4.])

6. 保持维度

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
norm = torch.norm(A, p=2, dim=1, keepdim=True)
print(norm)  # 输出: tensor([[2.2361], [5.0000]])

总结

torch.norm 是一个非常灵活的函数,可以用于计算各种类型的范数,并且可以指定计算的维度。通过 keepdim 参数,还可以控制输出的维度是否与输入保持一致。