在 Stable-Baselines3 中,BaseFeaturesExtractor 是一个基类,用于从原始观测数据中提取特征,供强化学习模型使用。它通常用于处理高维或复杂的观测数据(如图像),将其转换为低维特征向量,便于模型处理。
主要作用
- 特征提取:将原始观测数据(如图像)转换为低维特征向量。
- 自定义网络:允许用户定义自己的特征提取网络,适应不同的观测数据。
- 与策略网络集成:提取的特征会作为策略网络的输入,用于决策。
使用场景
- 图像数据:使用卷积神经网络(CNN)提取图像特征。
- 其他复杂数据:如LSTM或MLP处理时间序列或结构化数据。
示例代码
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomFeaturesExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space, features_dim=256):
super().__init__(observation_space, features_dim)
self.cnn = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
nn.Flatten(),
)
with torch.no_grad():
sample = torch.as_tensor(observation_space.sample()[None]).float()
n_flatten = self.cnn(sample).shape[1]
self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
def forward(self, observations):
return self.linear(self.cnn(observations))
# 使用自定义特征提取器
from stable_baselines3 import PPO
from gym import spaces
observation_space = spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8)
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs={
"features_extractor_class": CustomFeaturesExtractor,
"features_extractor_kwargs": {"features_dim": 256},
})
总结
BaseFeaturesExtractor 是 Stable-Baselines3 中用于特征提取的基类,用户可以通过继承它自定义特征提取网络,适应不同的观测数据。