I am trying to understand the structure of the custom recurrent policy introduced in the Stable Baselines documentation.
How exactly is the LSTM network constructed? (See the code below.)
From what I understand from the documentation, `net_arch=[8, 'lstm']` in this case means that before the LSTM there is a fully connected network with a hidden layer of 8 nodes.
A crude illustration would be:
observation (input) -> 8 hidden nodes -> LSTM -> action (output)
Let's say I want to construct the following network:
observation -> hidden layer of 8 nodes -> hidden layer of 16 nodes -> LSTM -> hidden layer of 16 nodes -> output layer (outputs of the policy and value networks)
Would I have to write `net_arch=[8, 16, 'lstm', 16]`? Is this correct?
Also, what exactly does `feature_extraction="mlp"` mean?
```python
from stable_baselines.common.policies import LstmPolicy

# Custom recurrent policy, as given in the Stable Baselines documentation
class CustomLSTMPolicy(LstmPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=64, reuse=False, **_kwargs):
        super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                         net_arch=[8, 'lstm', dict(vf=[5, 10], pi=[10])],
                         layer_norm=True, feature_extraction="mlp", **_kwargs)
```
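The reading of the `net_arch` notation described above can be made concrete with a minimal sketch. It assumes that an int is a shared fully connected layer of that width, that `'lstm'` is the single LSTM layer in the shared part, and that a trailing dict splits the network into separate `pi` and `vf` layers. `describe_net_arch` is a hypothetical helper written only for illustration: it is not part of Stable Baselines and builds no actual network, it only lists the layer chain each head would see under those assumptions.

```python
from typing import Dict, List, Union

# Hypothetical helper, NOT part of Stable Baselines: it only mirrors the
# assumed net_arch notation so the layer order can be inspected.
NetArch = List[Union[int, str, Dict[str, List[int]]]]


def describe_net_arch(net_arch: NetArch) -> Dict[str, List[str]]:
    """Expand a net_arch list into the layer chains of the pi and vf heads."""
    shared: List[str] = ["obs"]
    pi_only: List[int] = []
    vf_only: List[int] = []
    seen_lstm = False
    for layer in net_arch:
        if isinstance(layer, int):
            # Assumption: an int is a shared fully connected layer.
            shared.append(f"fc({layer})")
        elif layer == "lstm":
            # Assumption: exactly one LSTM layer, placed in the shared part.
            if seen_lstm:
                raise ValueError("'lstm' may appear only once")
            shared.append("lstm")
            seen_lstm = True
        elif isinstance(layer, dict):
            # Assumption: a dict ends the shared part and splits the heads.
            pi_only = list(layer.get("pi", []))
            vf_only = list(layer.get("vf", []))
            break
        else:
            raise ValueError(f"unsupported net_arch entry: {layer!r}")
    if not seen_lstm:
        raise ValueError("net_arch must contain 'lstm'")
    return {
        "pi": shared + [f"fc({n})" for n in pi_only] + ["action"],
        "vf": shared + [f"fc({n})" for n in vf_only] + ["value"],
    }


if __name__ == "__main__":
    examples = {
        "docs": [8, "lstm", dict(vf=[5, 10], pi=[10])],  # documentation example
        "proposed": [8, 16, "lstm", 16],                # network from the question
    }
    for name, arch in examples.items():
        for head, chain in describe_net_arch(arch).items():
            print(f"{name} {head}: " + " -> ".join(chain))
```

Under these assumptions, `[8, 16, 'lstm', 16]` expands to `obs -> fc(8) -> fc(16) -> lstm -> fc(16)` followed by the action and value outputs, with both heads sharing every layer because no dict is given. Whether `LstmPolicy` actually builds the network this way is what the question is asking.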