torch.arange 是 PyTorch 中的一个函数,用于生成一个一维张量(tensor),其中包含从起始值到结束值(不包括结束值)的等间隔序列。这个函数类似于 Python 中的 range 函数,但返回的是一个 PyTorch 张量。

函数签名

torch.arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor

参数说明

  • start (float, optional): 序列的起始值,默认为 0。
  • end (float): 序列的结束值(不包括该值)。
  • step (float, optional): 序列中相邻两个元素之间的步长,默认为 1。
  • out (Tensor, optional): 输出张量。如果提供,结果将写入这个张量中。
  • dtype (torch.dtype, optional): 返回张量的数据类型。如果未指定,则根据输入参数推断数据类型。
  • layout (torch.layout, optional): 返回张量的内存布局。默认为 torch.strided
  • device (torch.device, optional): 返回张量所在的设备(如 CPU 或 GPU)。默认为当前默认设备。
  • requires_grad (bool, optional): 如果为 True,则返回的张量将记录操作以支持自动求导。默认为 False

返回值

返回一个一维张量,包含从 startend(不包括 end)的等间隔序列。

示例

1、 基本用法

import torch

# 生成从 0 到 4 的整数序列
x = torch.arange(5)
print(x)  # 输出: tensor([0, 1, 2, 3, 4])

2、 指定起始值和步长

# 生成从 1 到 10,步长为 2 的序列
x = torch.arange(1, 10, 2)
print(x)  # 输出: tensor([1, 3, 5, 7, 9])

3、 指定数据类型

# 生成从 0 到 4 的浮点数序列
x = torch.arange(5, dtype=torch.float32)
print(x)  # 输出: tensor([0., 1., 2., 3., 4.])

4、 指定设备

# 生成从 0 到 4 的序列,并将其放在 GPU 上
x = torch.arange(5, device='cuda')
print(x)  # 输出: tensor([0, 1, 2, 3, 4], device='cuda:0')

5、 支持自动求导

# 生成从 0 到 4 的序列,并启用自动求导
x = torch.arange(5, requires_grad=True)
print(x)  # 输出: tensor([0, 1, 2, 3, 4], requires_grad=True)

注意事项

  • torch.arange 生成的序列不包括 end 值。
  • 如果 startend 的值相等,则返回一个空张量。
  • step 可以是正数或负数,但必须与 startend 的方向一致,否则会返回空张量。

torch.range 的区别

torch.rangetorch.arange 的旧版本,它在生成序列时包括 end 值。由于 torch.range 的行为与 Python 的 range 函数不一致,因此在 PyTorch 1.11 版本中已被弃用,建议使用 torch.arange 代替。

# 不推荐使用 torch.range
x = torch.range(1, 5)  # 输出: tensor([1, 2, 3, 4, 5])

总之,torch.arange 是一个非常有用的函数,用于生成等间隔的数值序列,广泛应用于深度学习模型的构建和数据处理中。