I am trying to understand the structure of the custom recurrent policy introduced in the Stable Baselines documentation.

How exactly is the LSTM network constructed? (See the code below.)

From what I understood from the documentation, net_arch=[8, 'lstm'] means that there is a fully connected hidden layer of size 8 before the LSTM. A crude illustration would be:

observation (input) -> 8 hidden nodes -> LSTM -> action (output)

Let's say I want to construct the following network:

observation -> hidden layer of 8 nodes -> hidden layer of 16 nodes -> LSTM -> hidden layer of 16 nodes -> output layer (outputs of the policy and value networks)

Would I have to write net_arch=[8, 16, 'lstm', 16]? Is this correct? Also, what exactly does feature_extraction="mlp" mean?

```python
from stable_baselines.common.policies import LstmPolicy


# Custom LSTM policy, as given in the Stable Baselines documentation
class CustomLSTMPolicy(LstmPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=64, reuse=False, **_kwargs):
        super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                         net_arch=[8, 'lstm', dict(vf=[5, 10], pi=[10])],
                         layer_norm=True, feature_extraction="mlp", **_kwargs)
```
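To make the question concrete, here is how I am currently reading a net_arch list, sketched in plain Python. This is only my interpretation and not the library's implementation: describe_net_arch is a hypothetical helper, and it assumes that integers are shared dense layers applied in order, that 'lstm' marks the single LSTM cell, and that a trailing dict splits the network into separate policy (pi) and value (vf) heads.

```python
def describe_net_arch(net_arch):
    """Return the layer sequence implied by a net_arch list (hypothetical helper).

    Assumed reading: ints are shared dense layers, 'lstm' is the LSTM cell,
    and a dict ends the shared part and defines the pi/vf heads.
    """
    stages = ["observation"]
    seen_lstm = False
    for item in net_arch:
        if isinstance(item, int):
            where = "after" if seen_lstm else "before"
            stages.append(f"shared dense({item}) [{where} LSTM]")
        elif item == "lstm":
            if seen_lstm:
                raise ValueError("expected only one 'lstm' entry")
            seen_lstm = True
            stages.append("LSTM")
        elif isinstance(item, dict):
            pi_layers = item.get("pi", [])
            vf_layers = item.get("vf", [])
            stages.append(f"split: pi{pi_layers} -> action, vf{vf_layers} -> value")
            break  # assumed: nothing is read after the pi/vf split
        else:
            raise TypeError(f"unexpected net_arch entry: {item!r}")
    return stages


# The architecture from the documentation example
print(" -> ".join(describe_net_arch([8, "lstm", dict(vf=[5, 10], pi=[10])])))

# The architecture I would like to build
print(" -> ".join(describe_net_arch([8, 16, "lstm", 16])))
```

If this reading is right, [8, 16, 'lstm', 16] would give two shared layers before the LSTM and one shared layer of 16 nodes after it, with the action and value outputs attached directly to that last layer.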