在PyTorch中,torch.nn.Module模块中的state_dict函数获得的一个字典变量,其存放训练过程中需要学习的权重和偏执系数,如下所示:
代码1:
import torch.nn as nn
module = nn.Linear(2, 2)
print(module.state_dict().keys())
代码2:
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
module = MyModel()
print(module.state_dict().keys())
参考
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.state_dict