在做AI业务开发时候,常常需要pytorch的tensor类型和numpy类型进行转换,下面给大家介绍一下两者的转换过程:

首先,导入需要使用的包:

import numpy as np
import torch

然后,创建一个numpy类型的数组:

x = np.ones(3)
print(type(x))

这里创建了一个一维的数组,3个都为1,我们打印一下这个x的类型显示如下:

<class 'numpy.ndarray'>

1、numpy类型转为tensor类型

用下面的代码将上述的x转换成tensor类型:

y = torch.tensor(x)
print(type(x))

这个打印的结果是:

<class 'torch.Tensor'> 

当然,也可以使用:

y = torch.from_numpy(x)

2、tensor类型转为numpy类型

import torch
x = torch.ones(3) 
y = x.detach().numpy()

也可以使用:

y = x.numpy()