在 PyTorch 中,你可以使用多种方法来合并向量(或张量)。以下是几种常见的方法:
1. 使用 torch.cat
进行拼接
torch.cat
可以沿着指定的维度拼接张量。
import torch
# 创建两个向量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 沿着第0维拼接
c = torch.cat((a, b), dim=0)
print(c) # 输出: tensor([1, 2, 3, 4, 5, 6])
2. 使用 torch.stack
进行堆叠
torch.stack
可以在新的维度上堆叠张量。
import torch
# 创建两个向量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 在第0维上堆叠
c = torch.stack((a, b), dim=0)
print(c) # 输出: tensor([[1, 2, 3],
# [4, 5, 6]])
3. 使用 torch.hstack
进行水平拼接
torch.hstack
可以将张量在水平方向(第1维)上拼接。
import torch
# 创建两个向量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 水平拼接
c = torch.hstack((a, b))
print(c) # 输出: tensor([1, 2, 3, 4, 5, 6])
4. 使用 torch.vstack
进行垂直拼接
torch.vstack
可以将张量在垂直方向(第0维)上拼接。
import torch
# 创建两个向量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 垂直拼接
c = torch.vstack((a, b))
print(c) # 输出: tensor([[1, 2, 3],
# [4, 5, 6]])
5. 使用 torch.column_stack
进行列拼接
torch.column_stack
可以将张量按列拼接。
import torch
# 创建两个向量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 列拼接
c = torch.column_stack((a, b))
print(c) # 输出: tensor([[1, 4],
# [2, 5],
# [3, 6]])
6. 使用 torch.row_stack
进行行拼接
torch.row_stack
可以将张量按行拼接。
import torch
# 创建两个向量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 行拼接
c = torch.row_stack((a, b))
print(c) # 输出: tensor([[1, 2, 3],
# [4, 5, 6]])
总结
torch.cat
是最通用的拼接方法,可以指定拼接的维度。torch.stack
会在新的维度上堆叠张量。torch.hstack
和torch.vstack
分别用于水平和垂直拼接。torch.column_stack
和torch.row_stack
分别用于按列和按行拼接。
根据你的需求选择合适的方法来合并向量。