torch.nn.Module 是 PyTorch 中用于构建神经网络模型的基础类。所有的神经网络模型都应该继承自 torch.nn.Module,并且在该类中定义模型的结构和前向传播的逻辑。torch.nn.Module 提供了许多有用的功能,例如参数管理、模型保存与加载、设备管理(CPU/GPU)等。

主要功能

  1. 参数管理:

    • torch.nn.Module 会自动跟踪所有通过 torch.nn.Parameter 定义的参数。这些参数可以通过 parameters() 方法访问,通常用于传递给优化器进行训练。
    • state_dict() 方法返回一个包含模型所有参数的字典,通常用于保存和加载模型。
  2. 前向传播:

    • 子类需要实现 forward(self, *input) 方法,定义模型的前向传播逻辑。
    • 在调用模型时(例如 model(x)),forward 方法会被自动调用。
  3. 设备管理:

    • to(device) 方法可以将模型的所有参数和缓冲区移动到指定的设备(如 CPU 或 GPU)。
    • cpu()cuda() 方法分别将模型移动到 CPU 或 GPU。
  4. 模型保存与加载:

    • torch.save(model.state_dict(), 'model.pth') 可以保存模型的参数。
    • model.load_state_dict(torch.load('model.pth')) 可以加载模型的参数。
  5. 子模块管理:

    • add_module(name, module) 方法可以添加子模块。
    • children()modules() 方法可以遍历模型的子模块。
  6. 钩子(Hooks):

    • register_forward_hook(hook)register_backward_hook(hook) 可以注册前向和后向传播的钩子,用于调试或可视化。

示例代码

以下是一个简单的 torch.nn.Module 示例,展示如何定义一个简单的全连接神经网络:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # 输入层到隐藏层
        self.fc2 = nn.Linear(128, 64)   # 隐藏层到隐藏层
        self.fc3 = nn.Linear(64, 10)    # 隐藏层到输出层

    def forward(self, x):
        x = F.relu(self.fc1(x))  # 激活函数
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型实例
model = SimpleNet()

# 打印模型结构
print(model)

# 前向传播
input_data = torch.randn(1, 784)  # 随机输入数据
output = model(input_data)
print(output)

关键方法

  • forward(self, *input): 定义模型的前向传播逻辑。
  • parameters(): 返回模型的所有参数。
  • state_dict(): 返回模型的状态字典,包含所有参数和缓冲区。
  • load_state_dict(state_dict): 加载模型的状态字典。
  • to(device): 将模型移动到指定的设备。
  • train()eval(): 设置模型为训练模式或评估模式。

总结

torch.nn.Module 是 PyTorch 中构建神经网络模型的核心类。通过继承 torch.nn.Module 并实现 forward 方法,可以轻松定义和训练复杂的神经网络模型。torch.nn.Module 提供了丰富的功能,使得模型的定义、训练、保存和加载变得非常方便。