在 PyTorch 中,chunk 是一个用于将张量分割成多个块的函数。它的用法如下所示。

1、chunk函数定义

1.1、函数签名

torch.chunk(input, chunks, dim=0)

1.2、参数说明

  • input (Tensor): 要分割的输入张量。
  • chunks (int): 要分割的块数。
  • dim (int, optional): 沿着哪个维度进行分割,默认为 0。

1.3、返回值

返回一个包含 chunks 个张量的元组。

2、chunk 代码示例

import torch

# 创建一个形状为 (6, 3) 的张量
x = torch.arange(18).reshape(6, 3)
print("Original Tensor:")
print(x)

# 将张量沿着第 0 维度分割成 3 块
chunks = torch.chunk(x, chunks=3, dim=0)
print("\nChunks:")
for i, chunk in enumerate(chunks):
    print(f"Chunk {i+1}:")
    print(chunk)

输出:

Original Tensor:
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]])

Chunks:
Chunk 1:
tensor([[0, 1, 2],
        [3, 4, 5]])
Chunk 2:
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
Chunk 3:
tensor([[12, 13, 14],
        [15, 16, 17]])

3、注意事项

  • 如果 chunks 大于指定维度的长度,则返回的块数会少于 chunks,且每个块的大小可能不同。
  • 如果 chunks 不能整除指定维度的长度,则最后一个块会较小。

示例:不能整除的情况

# 将张量沿着第 0 维度分割成 4 块
chunks = torch.chunk(x, chunks=4, dim=0)
print("\nChunks (when not divisible):")
for i, chunk in enumerate(chunks):
    print(f"Chunk {i+1}:")
    print(chunk)

输出

Chunks (when not divisible):
Chunk 1:
tensor([[0, 1, 2],
        [3, 4, 5]])
Chunk 2:
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
Chunk 3:
tensor([[12, 13, 14]])
Chunk 4:
tensor([[15, 16, 17]])

4、总结

torch.chunk 是一个方便的工具,用于将张量沿指定维度分割成多个块。它在处理大规模数据时非常有用,尤其是在需要将数据分批处理时。