torch.as_strided 是 PyTorch 中的一个函数,用于创建一个新的张量(tensor),该张量与输入张量共享相同的数据存储,但具有不同的形状和步幅(stride)。这个函数通常用于实现一些高级的张量操作,比如滑动窗口、矩阵转置等。

参数说明

  • input (Tensor): 输入张量。
  • size (tuple of ints): 输出张量的形状。
  • stride (tuple of ints): 输出张量的步幅。
  • storage_offset (int, optional): 输出张量在存储中的起始偏移量,默认为 0。

返回值

返回一个与输入张量共享数据存储的新张量,但具有指定的形状和步幅。

使用场景

1、 滑动窗口操作:比如在卷积神经网络中,可以使用 torch.as_strided 来创建一个滑动窗口视图,而不需要实际复制数据。
2、 矩阵转置:通过调整步幅,可以实现矩阵的转置操作。
3、 高级索引操作:可以通过调整步幅和形状来实现一些复杂的索引操作。

示例

import torch

# 创建一个 1D 张量
x = torch.arange(10)
print("Original tensor:", x)

# 使用 as_strided 创建一个 2x5 的视图
y = torch.as_strided(x, size=(2, 5), stride=(5, 1))
print("Strided tensor:", y)

# 修改 y 会影响到 x
y[0, 0] = 100
print("Modified original tensor:", x)

输出

Original tensor: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Strided tensor: tensor([[0, 1, 2, 3, 4],
                       [5, 6, 7, 8, 9]])
Modified original tensor: tensor([100,   1,   2,   3,   4,   5,   6,   7,   8,   9])

注意事项

  • torch.as_strided 创建的新张量与原始张量共享数据存储,因此修改新张量会影响到原始张量。
  • 使用 torch.as_strided 时需要小心,因为不正确的步幅设置可能导致访问到无效的内存区域,从而引发错误或未定义行为。

这个函数在需要高效操作张量时非常有用,但也需要谨慎使用,以避免潜在的内存问题。