1、torch.stack()函数的作用

torch.stack()函数的参数形式为:torch.stack(inputs,dim=0,out=None),其作用是将若干个形状相同的张量在dim维度上连接,生成一个扩维的张量。比如,我们原本有若干个2维张量,连接之后可以得到一个3维的张量。

2、torch.stack()函数的参数说明

(1)inputs: 待连接的张量序列。
注:python的序列数据只有list和tuple。

(2)dim: 新的维度,必须在0到len(outputs.shape)之间,即:0<=dim<len(outputs.shape)。
注:len(outputs.shape)是生成数据的维度大小,也就是outputs的维度值,它比inputs的维度要多出一个。

3、torch.stack()函数的使用

import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]])

b = torch.tensor([[10, 20, 30], [40, 50, 60]])

print(torch.stack([a, b], dim=0))
print(torch.stack([a, b], dim=0).size())
print('#' * 24)

print(torch.stack([a, b], dim=1))
print(torch.stack([a, b], dim=1).size())
print('#' * 24)

print(torch.stack([a, b], dim=2))
print(torch.stack([a, b], dim=2).size())

结果为:

tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[10, 20, 30],
         [40, 50, 60]]])
torch.Size([2, 2, 3])
########################
tensor([[[ 1,  2,  3],
         [10, 20, 30]],

        [[ 4,  5,  6],
         [40, 50, 60]]])
torch.Size([2, 2, 3])
########################
tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],

        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]]])
torch.Size([2, 3, 2])

4、torch.stack()函数的规律分析

此处主要以二维向量的stack为主,因为二维向量经过stack之后变成三维向量,大家容易理解。
dim=0时,将tensor进行叠加从而形成三维向量,这种情况比较容易理解。
dim=1时,将每个tensor的第i行抽出连接组成一个新的2维tensor,然后进行叠加从而形成三维向量。
dim=2时,将每个tensor的第i列抽出连接组成一个新的2维tensor,然后进行叠加从而形成三维向量。

如果大家看不懂,可以看下面的示意图,这样很形象: