1、model.state_dict()函数简介
state_dict是Python的字典对象(具体来说,是OrderedDict字典类型),可用于保存模型参数、超参数以及优化器(torch.optim)的状态信息。需要注意的是,只有具有可学习参数的层(如卷积层、线性层等)才有state_dict。
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.nn1 = nn.Linear(2, 3)
self.nn2 = nn.Linear(3, 6)
def forward(self, x):
x = F.relu(self.nn1(x))
return F.relu(self.nn2(x))
model = MyModel()
print(model.state_dict())
OrderedDict([
('nn1.weight', tensor([[-0.1838, -0.2477],[ 0.4845, 0.3157],[-0.5628, 0.3612]])),
('nn1.bias', tensor([-0.4328, -0.6779, 0.3845])),
('nn2.weight', tensor([[-5.0585e-01, -4.6973e-01, 1.6044e-02],[-3.4606e-01, 1.1130e-01, -2.0727e-01],
[-3.9844e-02, -4.2531e-01, 8.2558e-02],[ 3.3171e-02, -3.4334e-01, 4.5039e-01],
[-2.5320e-04, -5.2037e-01, 1.3504e-02],[-3.0776e-01, 8.9345e-02, -1.1076e-01]])),
('nn2.bias', tensor([ 0.1229, -0.2344, 0.0568, -0.3430, 0.2715, -0.3521]))
])
2、model.state_dict()作用
model.state_dict()函数的作用是保存模型,如下所示:
torch.save(model.state_dict(), 'model_weights.pth')
3、model.state_dict()与model.parameters()、model.named_parameters()的区别
首先,说说比较接近的model.parameters()和model.named_parameters()。这两者唯一的差别在于,named_parameters()返回的list中,每个元组打包了2个内容,分别是layer-name和layer-param,而parameters()只有后者。named_parameters()的函数如下所示:
def named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]:
model.state_dict()是将layer_name:layer_param
的键值信息存储为dict字典里面,而model.named_parameters()则是打包成一个元组然后再存到list当中;另外,model.state_dict()存储的是该model中包含的所有layer中的所有参数;而model.named_parameters()则只保存可学习、可被更新的参数,model.buffer()中的参数不包含在model.named_parameters()中。
补充说明:Pytorch中的缓冲区buffer
在深度学习中,参数(parameters)是神经网络模型的可学习的权重和偏置。模型的参数通常是模型的一部分,并且可以通过模型的状态字典进行访问和更新。而缓冲区(buffer)是与模型的参数有关的非学习的状态数据。
缓冲区对象是 Pytorch 中的 torch.nn.Module 类的成员之一。在模型定义过程中,我们可以通过使用self.register_buffer() 方法将缓冲区添加到模型中,并在模型中通过名称访问缓冲区。
下面是一个示例,说明如何在 Pytorch 中定义和使用缓冲区:
import torch
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.register_buffer('buffer', torch.zeros(3, 3))
model = MyModel()
print(model.buffer)
在上面的示例中,我们首先定义了一个名为 MyModel 的自定义模型类。在该类的构造函数中,我们通过调用 self.register_buffer() 方法向模型中添加了一个名为 buffer 的缓冲区。该缓冲区是一个大小为 3×3 的张量,并初始化为全零。最后,我们在模型实例上通过 model.buffer 的方式访问了缓冲区对象。
缓冲区的作用。缓冲区在深度学习中有多种应用场景,以下是一些缓冲区的常见用途:
1、保存运行统计信息。缓冲区可以用于保存模型训练过程中的运行统计信息,例如均值、标准差等。这些统计信息可以在模型的推理阶段用于归一化输入数据或其他预处理操作。
2、存储固定的张量。缓冲区可以用于存储固定的张量,例如预训练模型的权重或卷积核。这些固定的张量可以在模型的训练过程中保持不变,并在推理过程中使用。
3、缓存中间计算结果。在模型的前向传播过程中,缓冲区可以用于存储中间的计算结果,以便它们在后续的计算中被重用。这样可以提高计算效率,并减少计算的重复性。
4、保存模型相关的状态信息。缓冲区可以用于保存模型相关的状态信息,例如迭代次数、学习率等。这些状态信息可以在模型训练过程中进行更新,并用于优化算法的调整。
后记:读者推荐
如果想真正的学好PyTorch,极力推荐大家关注《PyTorch面试精华》