0

我在rllib文档中没有看到任何可以让我像print(model.summary())在 keras 中那样打印模型的快速摘要的内容。我尝试使用 tf-slim 和

variables = tf.compat.v1.model_variables()
slim.model_analyzer.analyze_vars(variables, print_info=True)

对张量流模型有一个粗略的了解,但是在模型初始化后没有发现任何变量(插入到ESTrainer类_init 的末尾)。具体来说,我一直在尝试获取进化策略 (ES) 策略的摘要,以验证模型配置的更改是否按预期更新,但我无法获得摘要打印工作。

有没有现成的方法呢?苗条有望在这里工作吗?

4

1 回答 1

1

训练代理可以返回允许您访问模型的策略:

agent = ppo.PPOTrainer(config, env=select_env)

policy = agent.get_policy()
policy.model.base_model.summary() # Prints the model summary

样本输出:

 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 observations (InputLayer)      [(None, 7)]          0           []                               
                                                                                                  
 fc_1 (Dense)                   (None, 256)          2048        ['observations[0][0]']           
                                                                                                  
 fc_value_1 (Dense)             (None, 256)          2048        ['observations[0][0]']           
                                                                                                  
 fc_2 (Dense)                   (None, 256)          65792       ['fc_1[0][0]']                   
                                                                                                  
 fc_value_2 (Dense)             (None, 256)          65792       ['fc_value_1[0][0]']             
                                                                                                  
 fc_out (Dense)                 (None, 5)            1285        ['fc_2[0][0]']                   
                                                                                                  
 value_out (Dense)              (None, 1)            257         ['fc_value_2[0][0]']             
                                                                                                  
==================================================================================================
Total params: 137,222
Trainable params: 137,222
Non-trainable params: 0
于 2022-01-04T08:20:56.190 回答