torch.nn.ParameterList 是 PyTorch 中的一个容器类,用于存储和管理一组 torch.nn.Parameter 对象。ParameterList 类似于 Python 的列表(list),但它专门用于存储 Parameter 对象,并且这些 Parameter 对象会被自动注册到模型的参数列表中,从而可以在训练过程中被优化器更新。

主要特点

  1. 自动注册参数ParameterList 中的 Parameter 对象会自动注册到模型的参数列表中,因此它们会被优化器识别并更新。
  2. 动态扩展ParameterList 可以像普通的 Python 列表一样动态地添加或删除 Parameter 对象。
  3. 索引访问:可以通过索引访问 ParameterList 中的 Parameter 对象。

使用示例

以下是一个简单的示例,展示了如何使用 torch.nn.ParameterList

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 创建一个 ParameterList 并添加一些 Parameter 对象
        self.params = nn.ParameterList([
            nn.Parameter(torch.randn(2, 2)),
            nn.Parameter(torch.randn(3, 3))
        ])
        
        # 动态添加一个新的 Parameter
        self.params.append(nn.Parameter(torch.randn(4, 4)))

    def forward(self, x):
        # 使用 ParameterList 中的参数
        for param in self.params:
            x = torch.matmul(x, param)
        return x

# 创建模型实例
model = MyModel()

# 打印模型的参数
for name, param in model.named_parameters():
    print(f"{name}: {param}")

# 前向传播
input_tensor = torch.randn(1, 2)
output = model(input_tensor)
print(output)

解释

  1. 初始化:在 MyModel__init__ 方法中,我们创建了一个 ParameterList 并添加了两个 Parameter 对象。然后,我们动态地添加了第三个 Parameter
  2. 前向传播:在 forward 方法中,我们遍历 ParameterList 中的所有 Parameter 对象,并将它们与输入 x 进行矩阵乘法操作。
  3. 参数注册ParameterList 中的 Parameter 对象会自动注册到模型的参数列表中,因此它们会被优化器识别并更新。

注意事项

  • ParameterList 中的 Parameter 对象必须是 torch.nn.Parameter 类型,不能是普通的张量。
  • ParameterList 本身不是一个 Parameter,因此它不会被优化器更新。

适用场景

  • 当你需要动态地管理一组 Parameter 对象时,ParameterList 是一个非常有用的工具。例如,在某些情况下,你可能需要在模型的不同层中使用不同数量的参数,这时 ParameterList 可以帮助你灵活地管理这些参数。