torch.optim.Optimizer
是 PyTorch 中用于实现各种优化算法的基类。优化器的作用是根据计算出的梯度更新模型的参数,以最小化损失函数。PyTorch 提供了多种优化器,如 SGD、Adam、RMSprop 等,这些优化器都继承自 torch.optim.Optimizer
类。
1. 基本用法
要使用优化器,首先需要创建一个优化器实例,并将模型的参数传递给它。然后,在训练循环中,调用优化器的 step()
方法来更新模型的参数。
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的模型
model = nn.Linear(10, 1)
# 定义损失函数
criterion = nn.MSELoss()
# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练循环
for epoch in range(100):
# 前向传播
outputs = model(torch.randn(10))
loss = criterion(outputs, torch.randn(1))
# 反向传播
optimizer.zero_grad() # 清除之前的梯度
loss.backward() # 计算梯度
# 更新参数
optimizer.step() # 更新模型参数
2. 主要方法
zero_grad()
: 清除所有被优化参数的梯度。在每次更新参数之前调用,以避免梯度累积。step()
: 执行一次参数更新。根据计算出的梯度更新模型的参数。state_dict()
: 返回优化器的状态字典,包含优化器的所有状态信息。可以用于保存和加载优化器的状态。load_state_dict(state_dict)
: 从状态字典加载优化器的状态。通常用于恢复训练。
3. 常用优化器
PyTorch 提供了多种优化器,以下是一些常用的优化器:
SGD (随机梯度下降):
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
Adam (自适应矩估计):
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
RMSprop:
optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99)
Adagrad:
optimizer = optim.Adagrad(model.parameters(), lr=0.01)
Adadelta:
optimizer = optim.Adadelta(model.parameters(), lr=1.0, rho=0.9)
4. 参数组
torch.optim.Optimizer
支持参数组的概念,允许为不同的参数设置不同的超参数(如学习率)。这在微调预训练模型时非常有用。
optimizer = optim.SGD([
{'params': model.base.parameters(), 'lr': 0.001},
{'params': model.classifier.parameters(), 'lr': 0.01}
], momentum=0.9)
5. 学习率调度器
PyTorch 还提供了学习率调度器 (torch.optim.lr_scheduler
),用于在训练过程中动态调整学习率。
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(100):
# 训练代码...
scheduler.step() # 更新学习率
6. 自定义优化器
如果需要实现自定义的优化算法,可以继承 torch.optim.Optimizer
并实现 step()
方法。
class CustomOptimizer(torch.optim.Optimizer):
def __init__(self, params, lr=0.01):
defaults = dict(lr=lr)
super(CustomOptimizer, self).__init__(params, defaults)
def step(self):
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
p.data.add_(-group['lr'], p.grad.data)
7. 总结
torch.optim.Optimizer
是 PyTorch 中用于优化模型参数的核心类。通过使用不同的优化器和学习率调度器,可以有效地训练深度学习模型。理解优化器的工作原理和使用方法对于构建和训练神经网络模型至关重要。