torch.tensor()torch.as_tensor() 是 PyTorch 中用于创建张量的两个函数,它们的主要区别在于如何处理输入数据和内存共享。

1. torch.tensor()

  • 功能: torch.tensor() 总是会创建一个新的张量,并且会复制输入数据。
  • 内存共享: 不会与输入数据共享内存,即使输入数据已经是张量或 NumPy 数组。
  • 适用场景: 当你需要确保新张量与输入数据完全独立时使用。
import torch
import numpy as np

data = np.array([1, 2, 3])
tensor = torch.tensor(data)
print(tensor)  # 输出: tensor([1, 2, 3])

2. torch.as_tensor()

  • 功能: torch.as_tensor() 会尝试避免复制数据,尽可能与输入数据共享内存。
  • 内存共享: 如果输入数据已经是张量或 NumPy 数组,并且数据类型兼容,torch.as_tensor() 会共享内存,而不会复制数据。
  • 适用场景: 当你希望避免不必要的内存复制时使用,尤其是在处理大型数据集时。
import torch
import numpy as np

data = np.array([1, 2, 3])
tensor = torch.as_tensor(data)
print(tensor)  # 输出: tensor([1, 2, 3], dtype=torch.int32)

总结

  • torch.tensor(): 总是创建新的张量,复制数据,不与输入数据共享内存。
  • torch.as_tensor(): 尽可能共享内存,避免复制数据,适合处理大型数据集时使用。

选择使用哪个函数取决于你是否需要与输入数据共享内存。如果你需要确保新张量与输入数据完全独立,使用 torch.tensor();如果你希望避免不必要的内存复制,使用 torch.as_tensor()