I was trying to understand the policy networks in stable-baselines3 from this doc page.
As explained in this example, to specify a custom CNN feature extractor, we extend the `BaseFeaturesExtractor` class, set it via `features_extractor_class` in `policy_kwargs`, and pass `"CnnPolicy"` as the first parameter:

```python
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs)
```
Q1. Can we follow the same approach for a custom MLP feature extractor?
As explained in this example, to specify a custom MLP feature extractor, we extend the `ActorCriticPolicy` class, override `_build_mlp_extractor()`, and pass the custom policy class as the first parameter:

```python
class CustomActorCriticPolicy(ActorCriticPolicy):
    ...

model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
```
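Spelled out, my reading of that example is roughly the following sketch (`CustomNetwork` and its layer sizes are illustrative; the `forward_actor`/`forward_critic` split follows the docs example):

```python
import torch as th
from torch import nn
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomNetwork(nn.Module):
    """Replacement for the default MlpExtractor: separate policy/value branches."""

    def __init__(self, feature_dim: int,
                 last_layer_dim_pi: int = 64, last_layer_dim_vf: int = 64):
        super().__init__()
        # SB3 reads these attributes to size action_net and value_net
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        )
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        )

    def forward(self, features: th.Tensor):
        return self.policy_net(features), self.value_net(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    def _build_mlp_extractor(self) -> None:
        # self.features_dim is the output size of the features extractor
        self.mlp_extractor = CustomNetwork(self.features_dim)


model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
```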
Q2. Can we follow the same approach for a custom CNN feature extractor?
I feel we can have either a CNN extractor or an MLP extractor, but not both. So it makes no sense to pass `MlpPolicy` as the first parameter to the model and then specify a CNN feature extractor via `features_extractor_class` in `policy_kwargs`, as in this example. This results in the following policy (containing both a `features_extractor` and an `mlp_extractor`), which I feel is incorrect:

```
ActorCriticPolicy(
  (features_extractor): Net(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (fc3): Linear(in_features=384, out_features=512, bias=True)
  )
  (mlp_extractor): MlpExtractor(
    (shared_net): Sequential(
      (0): Linear(in_features=512, out_features=64, bias=True)
      (1): ReLU()
    )
    (policy_net): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=16, bias=True)
      (3): ReLU()
    )
    (value_net): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=16, bias=True)
      (3): ReLU()
    )
  )
  (action_net): Linear(in_features=16, out_features=7, bias=True)
  (value_net): Linear(in_features=16, out_features=1, bias=True)
)
```
Q3. Is my understanding correct? If yes, is one of the MLP or CNN feature extractors ignored?