torch.nn.Module
是 PyTorch 中用于构建神经网络模型的基础类。所有的神经网络模型都应该继承自 torch.nn.Module
,并且在该类中定义模型的结构和前向传播的逻辑。torch.nn.Module
提供了许多有用的功能,例如参数管理、模型保存与加载、设备管理(CPU/GPU)等。
主要功能
参数管理:
torch.nn.Module
会自动跟踪所有通过torch.nn.Parameter
定义的参数。这些参数可以通过parameters()
方法访问,通常用于传递给优化器进行训练。state_dict()
方法返回一个包含模型所有参数的字典,通常用于保存和加载模型。
前向传播:
- 子类需要实现
forward(self, *input)
方法,定义模型的前向传播逻辑。 - 在调用模型时(例如
model(x)
),forward
方法会被自动调用。
- 子类需要实现
设备管理:
to(device)
方法可以将模型的所有参数和缓冲区移动到指定的设备(如 CPU 或 GPU)。cpu()
和cuda()
方法分别将模型移动到 CPU 或 GPU。
模型保存与加载:
torch.save(model.state_dict(), 'model.pth')
可以保存模型的参数。model.load_state_dict(torch.load('model.pth'))
可以加载模型的参数。
子模块管理:
add_module(name, module)
方法可以添加子模块。children()
和modules()
方法可以遍历模型的子模块。
钩子(Hooks):
register_forward_hook(hook)
和register_backward_hook(hook)
可以注册前向和后向传播的钩子,用于调试或可视化。
示例代码
以下是一个简单的 torch.nn.Module
示例,展示如何定义一个简单的全连接神经网络:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(784, 128) # 输入层到隐藏层
self.fc2 = nn.Linear(128, 64) # 隐藏层到隐藏层
self.fc3 = nn.Linear(64, 10) # 隐藏层到输出层
def forward(self, x):
x = F.relu(self.fc1(x)) # 激活函数
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 创建模型实例
model = SimpleNet()
# 打印模型结构
print(model)
# 前向传播
input_data = torch.randn(1, 784) # 随机输入数据
output = model(input_data)
print(output)
关键方法
forward(self, *input)
: 定义模型的前向传播逻辑。parameters()
: 返回模型的所有参数。state_dict()
: 返回模型的状态字典,包含所有参数和缓冲区。load_state_dict(state_dict)
: 加载模型的状态字典。to(device)
: 将模型移动到指定的设备。train()
和eval()
: 设置模型为训练模式或评估模式。
总结
torch.nn.Module
是 PyTorch 中构建神经网络模型的核心类。通过继承 torch.nn.Module
并实现 forward
方法,可以轻松定义和训练复杂的神经网络模型。torch.nn.Module
提供了丰富的功能,使得模型的定义、训练、保存和加载变得非常方便。