torch.arange
是 PyTorch 中的一个函数,用于生成一个一维张量(tensor),其中包含从起始值到结束值(不包括结束值)的等间隔序列。这个函数类似于 Python 中的 range
函数,但返回的是一个 PyTorch 张量。
函数签名
torch.arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
参数说明
- start (float, optional): 序列的起始值,默认为 0。
- end (float): 序列的结束值(不包括该值)。
- step (float, optional): 序列中相邻两个元素之间的步长,默认为 1。
- out (Tensor, optional): 输出张量。如果提供,结果将写入这个张量中。
- dtype (
torch.dtype
, optional): 返回张量的数据类型。如果未指定,则根据输入参数推断数据类型。 - layout (
torch.layout
, optional): 返回张量的内存布局。默认为torch.strided
。 - device (
torch.device
, optional): 返回张量所在的设备(如 CPU 或 GPU)。默认为当前默认设备。 - requires_grad (bool, optional): 如果为
True
,则返回的张量将记录操作以支持自动求导。默认为False
。
返回值
返回一个一维张量,包含从 start
到 end
(不包括 end
)的等间隔序列。
示例
1、 基本用法
import torch
# 生成从 0 到 4 的整数序列
x = torch.arange(5)
print(x) # 输出: tensor([0, 1, 2, 3, 4])
2、 指定起始值和步长
# 生成从 1 到 10,步长为 2 的序列
x = torch.arange(1, 10, 2)
print(x) # 输出: tensor([1, 3, 5, 7, 9])
3、 指定数据类型
# 生成从 0 到 4 的浮点数序列
x = torch.arange(5, dtype=torch.float32)
print(x) # 输出: tensor([0., 1., 2., 3., 4.])
4、 指定设备
# 生成从 0 到 4 的序列,并将其放在 GPU 上
x = torch.arange(5, device='cuda')
print(x) # 输出: tensor([0, 1, 2, 3, 4], device='cuda:0')
5、 支持自动求导
# 生成从 0 到 4 的序列,并启用自动求导
x = torch.arange(5, requires_grad=True)
print(x) # 输出: tensor([0, 1, 2, 3, 4], requires_grad=True)
注意事项
torch.arange
生成的序列不包括end
值。- 如果
start
和end
的值相等,则返回一个空张量。 step
可以是正数或负数,但必须与start
和end
的方向一致,否则会返回空张量。
与 torch.range
的区别
torch.range
是 torch.arange
的旧版本,它在生成序列时包括 end
值。由于 torch.range
的行为与 Python 的 range
函数不一致,因此在 PyTorch 1.11 版本中已被弃用,建议使用 torch.arange
代替。
# 不推荐使用 torch.range
x = torch.range(1, 5) # 输出: tensor([1, 2, 3, 4, 5])
总之,torch.arange
是一个非常有用的函数,用于生成等间隔的数值序列,广泛应用于深度学习模型的构建和数据处理中。