在PyTorch中,向量广播(Broadcasting)是一种在不同形状的张量之间执行逐元素操作的机制。广播机制使得PyTorch能够自动扩展较小的张量,使其与较大的张量形状兼容,从而进行逐元素操作,而无需显式复制数据。

1、广播规则

广播遵循以下规则:

  • 从尾部对齐:从张量的最右边维度开始,逐维度比较大小。
  • 维度大小相等或为1:如果两个张量在某个维度上的大小相等,或者其中一个张量在该维度上的大小为1,则可以进行广播。
  • 缺失维度视为1:如果两个张量的维度数不同,PyTorch会在较小张量的前面补1,使其维度数与较大张量相同。
  • 扩展大小为1的维度:在广播过程中,大小为1的维度会被扩展为与较大张量对应维度相同的大小。

2、示例

示例1:标量与向量的广播

import torch

a = torch.tensor([1, 2, 3])  # 形状为 (3,)
b = 2  # 标量,形状为 ()

# 标量b会被广播为形状(3,),相当于 [2, 2, 2]
c = a + b
print(c)  # 输出: tensor([3, 4, 5])

示例2:向量与矩阵的广播

a = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 形状为 (2, 3)
b = torch.tensor([1, 2, 3])  # 形状为 (3,)

# 向量b会被广播为形状(2, 3),相当于 [[1, 2, 3], [1, 2, 3]]
c = a + b
print(c)  # 输出: tensor([[2, 4, 6], [5, 7, 9]])

示例3:不同形状的张量广播

a = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]])  # 形状为 (2, 1, 3)
b = torch.tensor([1, 2, 3])  # 形状为 (3,)

# 向量b会被广播为形状(2, 1, 3),相当于 [[[1, 2, 3]], [[1, 2, 3]]]
c = a + b
print(c)  # 输出: tensor([[[2, 4, 6]], [[5, 7, 9]]])

3、注意事项

  • 性能:广播机制不会实际复制数据,因此不会增加内存开销。
  • 不兼容形状:如果两个张量的形状不满足广播规则,PyTorch会抛出错误。

4、总结

PyTorch的广播机制使得在不同形状的张量之间进行逐元素操作变得非常方便。理解广播规则有助于编写更简洁、高效的代码。