在 Stable-Baselines3 中,BaseFeaturesExtractor 是一个基类,用于从原始观测数据中提取特征,供强化学习模型使用。它通常用于处理高维或复杂的观测数据(如图像),将其转换为低维特征向量,便于模型处理。

主要作用

  • 特征提取:将原始观测数据(如图像)转换为低维特征向量。
  • 自定义网络:允许用户定义自己的特征提取网络,适应不同的观测数据。
  • 与策略网络集成:提取的特征会作为策略网络的输入,用于决策。

使用场景

  • 图像数据:使用卷积神经网络(CNN)提取图像特征。
  • 其他复杂数据:如LSTM或MLP处理时间序列或结构化数据。

示例代码

import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim=256):
        super().__init__(observation_space, features_dim)
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )
        with torch.no_grad():
            sample = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations):
        return self.linear(self.cnn(observations))

# 使用自定义特征提取器
from stable_baselines3 import PPO
from gym import spaces

observation_space = spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8)
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs={
    "features_extractor_class": CustomFeaturesExtractor,
    "features_extractor_kwargs": {"features_dim": 256},
})

总结

BaseFeaturesExtractor 是 Stable-Baselines3 中用于特征提取的基类,用户可以通过继承它自定义特征提取网络,适应不同的观测数据。