I created an initial network model with the following architecture.

from keras.models import Sequential
from keras.layers import Activation, Dense, Dropout, Flatten

def create_model(env):

    dropout_prob = 0.8  # aggressive dropout regularization (fraction of units dropped)
    num_units = 256  # number of neurons in each hidden layer

    model = Sequential()
    # keras-rl feeds observations with a leading window dimension of length 1
    model.add(Flatten(input_shape=(1,) + env.input_shape))
    model.add(Dense(num_units))
    model.add(Activation('relu'))

    model.add(Dense(num_units))
    model.add(Dropout(dropout_prob))
    model.add(Activation('relu'))

    # one output per action
    model.add(Dense(env.action_size))
    model.add(Activation('softmax'))
    print(model.summary())
    return model

I then construct a DQNAgent, which modifies the network architecture:

dqn = DQNAgent(model=model, nb_actions=env.action_size, memory=memory,
               nb_steps_warmup=settings['train']['warm_up'], 
               target_model_update=settings['train']['update_rate'], policy=policy, enable_dueling_network=True)
dqn.compile(Adam(lr=settings['train']['learning_rate']), metrics=['mse'])

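This produces a modified network architecture, as expected. The problem is that when I later try to reload the fitted network, the model returned by the original create_model function cannot accept the saved weights, because its layer architecture no longer matches.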

print(model.summary())

Layer (type)                 Output Shape              Param #   
=================================================================
flatten_49 (Flatten)         (None, 106)               0         
_________________________________________________________________
dense_147 (Dense)            (None, 128)               13696     
_________________________________________________________________
activation_145 (Activation)  (None, 128)               0         
_________________________________________________________________
dense_148 (Dense)            (None, 64)                8256      
_________________________________________________________________
dropout_49 (Dropout)         (None, 64)                0         
_________________________________________________________________
activation_146 (Activation)  (None, 64)                0         
_________________________________________________________________
dense_149 (Dense)            (None, 3)                 195       
_________________________________________________________________
activation_147 (Activation)  (None, 3)                 0         
=================================================================
Total params: 22,147
Trainable params: 22,147
Non-trainable params: 0

print(dqn.model.summary())
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_49_input (InputLayer (None, 1, 1, 106)         0         
_________________________________________________________________
flatten_49 (Flatten)         (None, 106)               0         
_________________________________________________________________
dense_147 (Dense)            (None, 128)               13696     
_________________________________________________________________
activation_145 (Activation)  (None, 128)               0         
_________________________________________________________________
dense_148 (Dense)            (None, 64)                8256      
_________________________________________________________________
dropout_49 (Dropout)         (None, 64)                0         
_________________________________________________________________
activation_146 (Activation)  (None, 64)                0         
_________________________________________________________________
dense_149 (Dense)            (None, 3)                 195       
_________________________________________________________________
dense_150 (Dense)            (None, 4)                 16        
_________________________________________________________________
lambda_3 (Lambda)            (None, 3)                 0         
=================================================================
Total params: 22,163
Trainable params: 22,163
Non-trainable params: 0
_________________________________________________________________
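
The two summaries differ only at the tail: the dueling version drops the final softmax activation and adds one more Dense layer followed by a Lambda. As a quick sanity check, the sketch below recomputes the parameter counts from the shapes printed above. The helper name dense_params is made up for illustration.

```python
def dense_params(n_in, n_out):
    """Parameters of a fully connected layer: weights plus biases."""
    return n_in * n_out + n_out

# Layer sizes read off the two summaries above.
base = [dense_params(106, 128), dense_params(128, 64), dense_params(64, 3)]
dueling_head = dense_params(3, 4)  # dense_150: 3 action outputs -> 3 + 1 units

base_total = sum(base)
dueling_total = base_total + dueling_head

print(base, base_total)              # [13696, 8256, 195] 22147
print(dueling_head, dueling_total)   # 16 22163
```

The 16 extra parameters all belong to dense_150, so a weight file saved from dqn.model contains one more weighted layer than the model built by create_model.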

So, without training a new dqn, I need a way to build a network that starts from the original architecture but has the DQN agent's modifications applied.
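
For context, here is a minimal pure-Python sketch of what the added Dense(4) and Lambda layers appear to compute, assuming keras-rl's default 'avg' dueling type. Under that assumption, the first of the four outputs is the state value V, the other three are per-action advantages A, and the Lambda combines them as Q = V + A - mean(A). The function name and the sample numbers are made up for illustration.

```python
def dueling_avg(head_output):
    """Combine [V, A_1, ..., A_n] into n Q-values: Q_i = V + A_i - mean(A)."""
    value, advantages = head_output[0], head_output[1:]
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]

# Four head outputs (1 value + 3 advantages) become three Q-values.
q_values = dueling_avg([2.0, 1.0, 2.0, 3.0])
print(q_values)  # [1.0, 2.0, 3.0]
```

This is why dense_150 has output shape (None, 4) while lambda_3 returns to (None, 3), matching the three actions.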

1 Answer

The best approach is to save the entire model:

dqn.model.save('xyz')

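Then, in the new function, load the whole model rather than only the weights, instead of recreating the original model and converting it to the dueling form: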

model = keras.models.load_model('xyz')
answered 2020-12-05T13:37:11.313