nn.Flatten()
是 PyTorch 中的一个神经网络层,用于将输入的多维张量(tensor)展平成一维张量。它通常用于将卷积层或池化层的输出展平,以便将其传递给全连接层(即线性层)。
具体作用:
- 输入:
nn.Flatten()
可以接受任意维度的张量作为输入。 - 输出:它将输入张量展平成一个一维张量,或者按照指定的起始维度展平。
参数:
start_dim
(可选):指定从哪个维度开始展平。默认值为1
,表示从第一个维度(通常是批处理维度之后的维度)开始展平。end_dim
(可选):指定展平的结束维度。默认值为-1
,表示展平到最后一个维度。
示例:
假设你有一个形状为 (batch_size, channels, height, width)
的张量,通常这是卷积层的输出。如果你想将其展平成一个形状为 (batch_size, channels * height * width)
的张量,可以使用 nn.Flatten()
。
import torch
import torch.nn as nn
# 假设输入是一个形状为 (batch_size, channels, height, width) 的张量
input_tensor = torch.randn(32, 3, 28, 28) # 32 是 batch_size, 3 是 channels, 28x28 是图像大小
# 使用 Flatten 层
flatten = nn.Flatten()
# 展平后的张量
output_tensor = flatten(input_tensor)
print(output_tensor.shape) # 输出: torch.Size([32, 2352]) # 2352 = 3 * 28 * 28
在这个例子中,nn.Flatten()
将输入的 (32, 3, 28, 28)
张量展平成了 (32, 2352)
的张量,其中 2352
是 3 * 28 * 28
的结果。
总结:
nn.Flatten()
主要用于将多维张量展平成一维或指定维度的张量,以便将其传递给全连接层或其他需要一维输入的层。