在做AI业务开发时候,常常需要pytorch的tensor类型和numpy类型进行转换,下面给大家介绍一下两者的转换过程:
首先,导入需要使用的包:
import numpy as np
import torch
然后,创建一个numpy类型的数组:
x = np.ones(3)
print(type(x))
这里创建了一个一维的数组,3个都为1,我们打印一下这个x的类型显示如下:
<class 'numpy.ndarray'>
1、numpy类型转为tensor类型
用下面的代码将上述的x转换成tensor类型:
y = torch.tensor(x)
print(type(x))
这个打印的结果是:
<class 'torch.Tensor'>
当然,也可以使用:
y = torch.from_numpy(x)
2、tensor类型转为numpy类型
import torch
x = torch.ones(3)
y = x.detach().numpy()
也可以使用:
y = x.numpy()