nn.Flatten() 是 PyTorch 中的一个神经网络层,用于将输入的多维张量(tensor)展平成一维张量。它通常用于将卷积层或池化层的输出展平,以便将其传递给全连接层(即线性层)。

具体作用:

  • 输入nn.Flatten() 可以接受任意维度的张量作为输入。
  • 输出:它将输入张量展平成一个一维张量,或者按照指定的起始维度展平。

参数:

  • start_dim(可选):指定从哪个维度开始展平。默认值为 1,表示从第一个维度(通常是批处理维度之后的维度)开始展平。
  • end_dim(可选):指定展平的结束维度。默认值为 -1,表示展平到最后一个维度。

示例:

假设你有一个形状为 (batch_size, channels, height, width) 的张量,通常这是卷积层的输出。如果你想将其展平成一个形状为 (batch_size, channels * height * width) 的张量,可以使用 nn.Flatten()

import torch
import torch.nn as nn

# 假设输入是一个形状为 (batch_size, channels, height, width) 的张量
input_tensor = torch.randn(32, 3, 28, 28)  # 32 是 batch_size, 3 是 channels, 28x28 是图像大小

# 使用 Flatten 层
flatten = nn.Flatten()

# 展平后的张量
output_tensor = flatten(input_tensor)

print(output_tensor.shape)  # 输出: torch.Size([32, 2352])  # 2352 = 3 * 28 * 28

在这个例子中,nn.Flatten() 将输入的 (32, 3, 28, 28) 张量展平成了 (32, 2352) 的张量,其中 23523 * 28 * 28 的结果。

总结:

nn.Flatten() 主要用于将多维张量展平成一维或指定维度的张量,以便将其传递给全连接层或其他需要一维输入的层。