1、torch.set_printoptions方法简介

torch.set_printoptions方法是来修改pytorch中的打印选项的,就是使用print打印tensor时,我们显示的元素精度,最多显示元素个数等一系列相关操作的选项。如下所示:

>>> # 限制tensor的输出精读
>>> torch.set_printoptions(precision=2)
>>> torch.tensor([1.12345])
tensor([1.12])

>>> # 限制所展示元素的个数,当超过threshold个元素的时候,显示成省略号
>>> torch.set_printoptions(threshold=5)
>>> torch.arange(10)
tensor([0, 1, 2, ..., 7, 8, 9])

>>> # 恢复成默认值
>>> torch.set_printoptions(profile='default')
>>> torch.tensor([1.12345])
tensor([1.1235])
>>> torch.arange(10)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

2、torch.set_printoptions参数介绍

torch.set_printoptions方法的参数众多,其函数原型为:

torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None)

本文接下来分别介绍各个参数的作用,欢迎大家的收藏和转发。

precision=None

此参数的意思是显示浮点tensor中元素的精度(显示到小数点后几位),默认是4

threshold=None

由于我们的进行训练网络时,tensor都会很大,里面的数据很多,不方便全部显示出来,显示超过一定的个数后就会进行折叠。此参数是指定tensor的数目超过多少时开始显示进行折叠。默认为1000。

edgeitems=None

此参数也与折叠有关,折叠后只显示前面的数据和后面的数据,此参数设置显示的前面与后面的数据的行数(注意这里是行数而不是个数),默认为3。

linewidth=None

此参数是指如果一行数据太多会插入换行符,此参数是指定每行的字符数(注意是字符数,不是数据的个数,这个一定要注意)到达多少时插入换行符,此参数对于显示时超过threshold而折叠的tensor并不适用。默认为80。

profile=None

此参数就是一个比较简便的显示方法,它有三个选项,分别为default、short、full用来满足我们的显示。

sci_mode=None

此参数是来指定显示的数字是否使用科学计数法,可以选择指定True或者False,如果选择None,那么是True还是False会由torch._tensor_str._Formatter来定义。值会自动的由框架来选择。一般默认为False。