torch.asarray
和 torch.as_tensor
是 PyTorch 中用于将输入数据转换为张量的两个函数,但它们的行为和用途有一些区别。
1. torch.asarray
- 功能:
torch.asarray
是 PyTorch 1.10 版本引入的一个函数,用于将输入数据(如列表、NumPy 数组等)转换为 PyTorch 张量。 行为:
- 如果输入数据已经是一个 PyTorch 张量,
torch.asarray
会直接返回该张量。 - 如果输入数据是一个 NumPy 数组,
torch.asarray
会创建一个新的张量,并且默认情况下不会共享内存(即会进行数据拷贝)。 - 如果输入数据是其他类型(如列表),
torch.asarray
会将其转换为张量。
- 如果输入数据已经是一个 PyTorch 张量,
- 内存共享: 默认情况下,
torch.asarray
不会共享内存,除非输入数据已经是 PyTorch 张量。
2. torch.as_tensor
- 功能:
torch.as_tensor
也是用于将输入数据转换为 PyTorch 张量的函数。 行为:
- 如果输入数据已经是一个 PyTorch 张量,
torch.as_tensor
会直接返回该张量。 - 如果输入数据是一个 NumPy 数组,
torch.as_tensor
会尝试共享内存(即不会进行数据拷贝),除非数据类型不兼容。 - 如果输入数据是其他类型(如列表),
torch.as_tensor
会将其转换为张量。
- 如果输入数据已经是一个 PyTorch 张量,
- 内存共享:
torch.as_tensor
会尽可能共享内存,以减少内存开销。
主要区别
- 内存共享:
torch.as_tensor
会尽可能共享内存,而torch.asarray
默认情况下不会共享内存(除非输入已经是 PyTorch 张量)。 - 使用场景: 如果你希望减少内存开销并且输入数据是 NumPy 数组,
torch.as_tensor
是更好的选择。如果你希望确保数据拷贝并且不共享内存,可以使用torch.asarray
。
示例
import torch
import numpy as np
# 创建一个 NumPy 数组
np_array = np.array([1, 2, 3])
# 使用 torch.asarray
tensor1 = torch.asarray(np_array)
print(tensor1) # 输出: tensor([1, 2, 3], dtype=torch.int32)
print(tensor1.data_ptr() == np_array.__array_interface__['data'][0]) # 输出: False (不共享内存)
# 使用 torch.as_tensor
tensor2 = torch.as_tensor(np_array)
print(tensor2) # 输出: tensor([1, 2, 3], dtype=torch.int32)
print(tensor2.data_ptr() == np_array.__array_interface__['data'][0]) # 输出: True (共享内存)
在这个例子中,torch.as_tensor
共享了 NumPy 数组的内存,而 torch.asarray
创建了一个新的张量,没有共享内存。
总结来说,torch.as_tensor
更适合在需要共享内存的场景中使用,而 torch.asarray
更适合在需要确保数据独立性的场景中使用。