在 PyTorch 中,张量的存储方式是按照内存中的连续块来组织的。当我们对张量进行转置操作(如 tensor.t() 或 tensor.transpose())时,PyTorch 并不会实际改变数据的存储顺序,而是通过修改张量的元数据(如步长 stride 和形状 shape)来实现转置的效果。这种设计是为了避免不必要的数据复制,从而提高效率。然而,这种操作会导致张量的内存布局变得非连续(non-contiguous)。

下面我们详细解释为什么转置后的张量是非连续的,以及如何解决这个问题。

1、什么是连续张量?

在 PyTorch 中,张量的连续性(contiguity)指的是张量的数据在内存中是否是按照逻辑顺序连续存储的。具体来说:

  • 连续张量:数据在内存中是按照张量的逻辑顺序(行优先或列优先)连续存储的。
  • 非连续张量:数据在内存中的存储顺序与张量的逻辑顺序不一致。

PyTorch 使用 步长(stride) 来描述张量在内存中的存储方式。步长是一个元组,表示在每个维度上移动一个元素需要跳过的内存位置。

2、为什么转置后的张量是非连续的?

当我们对张量进行转置操作时,PyTorch 只是修改了张量的形状(shape)和步长(stride),而不会改变数据在内存中的实际存储顺序。

示例

import torch

# 创建一个 2x3 的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("原始张量:")
print(x)
print("形状:", x.shape)
print("步长:", x.stride())
# 输出:
# 原始张量:
# tensor([[1, 2, 3],
#         [4, 5, 6]])
# 形状: torch.Size([2, 3])
# 步长: (3, 1)  # 表示在第一个维度上移动一行需要跳过 3 个元素,在第二个维度上移动一列需要跳过 1 个元素

# 对张量进行转置
y = x.t()
print("\n转置后的张量:")
print(y)
print("形状:", y.shape)
print("步长:", y.stride())
# 输出:
# 转置后的张量:
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])
# 形状: torch.Size([3, 2])
# 步长: (1, 3)  # 表示在第一个维度上移动一行需要跳过 1 个元素,在第二个维度上移动一列需要跳过 3 个元素

从上面的例子可以看出:

  • 原始张量 x 的步长是 (3, 1),表示数据在内存中是按行优先存储的。
  • 转置后的张量 y 的步长是 (1, 3),表示数据在内存中是按列优先存储的。

由于转置操作只是修改了步长,而没有改变数据在内存中的实际存储顺序,因此转置后的张量是非连续的。

3、如何判断张量是否是连续的?

可以使用 is_contiguous() 方法来判断一个张量是否是连续的。

print("x 是否是连续的:", x.is_contiguous())  # 输出: True
print("y 是否是连续的:", y.is_contiguous())  # 输出: False

4、如何使转置后的张量变为连续的?

如果需要对非连续张量进行某些操作(如 view),可以调用 contiguous() 方法将其变为连续张量。contiguous() 会返回一个新的张量,其数据在内存中是连续存储的。

示例

# 将转置后的张量变为连续的
z = y.contiguous()
print("\n连续化后的张量:")
print(z)
print("形状:", z.shape)
print("步长:", z.stride())
print("z 是否是连续的:", z.is_contiguous())
# 输出:
# 连续化后的张量:
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])
# 形状: torch.Size([3, 2])
# 步长: (2, 1)  # 表示在第一个维度上移动一行需要跳过 2 个元素,在第二个维度上移动一列需要跳过 1 个元素
# z 是否是连续的: True

5、为什么需要连续张量?

某些操作(如 view)要求张量是连续的,因为它们的实现依赖于数据在内存中的连续存储。如果张量是非连续的,这些操作会抛出错误。

# 尝试对非连续张量使用 view
try:
    y.view(-1)
except RuntimeError as e:
    print("错误:", e)
# 输出:
# 错误: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

# 对连续张量使用 view
z = y.contiguous()
print(z.view(-1))
# 输出:
# tensor([1, 4, 2, 5, 3, 6])

6、总结

  • 转置操作会修改张量的步长,但不会改变数据在内存中的存储顺序,因此转置后的张量是非连续的。
  • 可以使用 is_contiguous() 方法判断张量是否是连续的。
  • 如果需要将非连续张量变为连续张量,可以调用 contiguous() 方法。
  • 某些操作(如 view)要求张量是连续的,否则会抛出错误。