1、torch.stack()函数的作用
torch.stack()函数的参数形式为:torch.stack(inputs,dim=0,out=None)
,其作用是将若干个形状相同的张量在dim维度上连接,生成一个扩维的张量。比如,我们原本有若干个2维张量,连接之后可以得到一个3维的张量。
2、torch.stack()函数的参数说明
(1)inputs: 待连接的张量序列。
注:python的序列数据只有list和tuple。
(2)dim: 新的维度,必须在0到len(outputs.shape)之间,即:0<=dim<len(outputs.shape)。
注:len(outputs.shape)是生成数据的维度大小,也就是outputs的维度值,它比inputs的维度要多出一个。
3、torch.stack()函数的使用
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[10, 20, 30], [40, 50, 60]])
print(torch.stack([a, b], dim=0))
print(torch.stack([a, b], dim=0).size())
print('#' * 24)
print(torch.stack([a, b], dim=1))
print(torch.stack([a, b], dim=1).size())
print('#' * 24)
print(torch.stack([a, b], dim=2))
print(torch.stack([a, b], dim=2).size())
结果为:
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[10, 20, 30],
[40, 50, 60]]])
torch.Size([2, 2, 3])
########################
tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]]])
torch.Size([2, 2, 3])
########################
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]]])
torch.Size([2, 3, 2])
4、torch.stack()函数的规律分析
此处主要以二维向量的stack为主,因为二维向量经过stack之后变成三维向量,大家容易理解。
dim=0时,将tensor进行叠加从而形成三维向量,这种情况比较容易理解。
dim=1时,将每个tensor的第i行抽出连接组成一个新的2维tensor,然后进行叠加从而形成三维向量。
dim=2时,将每个tensor的第i列抽出连接组成一个新的2维tensor,然后进行叠加从而形成三维向量。
如果大家看不懂,可以看下面的示意图,这样很形象: