PyTorch中有很多浮点类型,例如,torch.float16、torch.float32、torch.float64等。当初始化tensor的时候,可以指定使用那种浮点数类型,如果我们不指定,那么PyTorch默认使用torch.float32类型。如下所示:

>>> torch.tensor([1.2, 3]).dtype
torch.float32

PyTorch中浮点类型默认为torch.float32,而复数tensor的默认类型为torch.complex64,这个很好理解,因为我们复数tensor需要浮点数组成实部和虚部,每一个都是32,所以复数tensor默认就是64,如下所示:

>>> torch.tensor([1.2, 3j]).dtype
torch.complex64

利用torch.set_default_dtype可以修改tensor的浮点类型,同样的,复数tensor也会受到影响,如下所示:

>>> torch.set_default_dtype(torch.float64)

>>> torch.tensor([1.2, 3]).dtype
torch.float64

>>> torch.tensor([1.2, 3j]).dtype
torch.complex128