I'm using RL4J (the reinforcement learning framework integrated into DeepLearning4J) to teach a car to complete a lap on a track in a video game.

After training, I save the model with the following code:

QLearningDiscreteConv<ScreenFrameState> dql = new QLearningDiscreteConv<>(mdp, RACING_NET_CONFIG, RACING_HP, RACING_QL, manager);
dql.train();
dql.getNeuralNet().save(model);

Once the model is saved, I want to see how it behaves, so I load it and let it play:

DQN load = DQN.load(model);
QLearningDiscreteConv<ScreenFrameState> dql = new QLearningDiscreteConv<>(mdp, load, RACING_HP, RACING_QL, manager);
dql.getPolicy().play(mdp);

But playing the loaded model fails with this error:

org.deeplearning4j.exception.DL4JInvalidInputException: Cannot do forward pass in Convolution layer (layer name = layer0, layer index = 0): input array depth does not match CNN layer configuration (data input depth = 109, [minibatch,inputDepth,height,width]=[1, 109, 150, 3]; expected input depth = 10) (layer name: layer0, layer index: 0)
   at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.preOutput(ConvolutionLayer.java:294)
   at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.preOutput(ConvolutionLayer.java:248)
   at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.activate(ConvolutionLayer.java:392)
   at org.deeplearning4j.nn.layers.AbstractLayer.activate(AbstractLayer.java:309)
   at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.activationFromPrevLayer(MultiLayerNetwork.java:789)
   at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForwardToLayer(MultiLayerNetwork.java:929)
   at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForward(MultiLayerNetwork.java:870)
   at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForward(MultiLayerNetwork.java:861)
   at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.silentOutput(MultiLayerNetwork.java:1906)
   at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1898)
   at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1871)
   at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1952)
   at org.deeplearning4j.rl4j.network.dqn.DQN.output(DQN.java:49)
   at org.deeplearning4j.rl4j.policy.DQNPolicy.nextAction(DQNPolicy.java:32)
   at org.deeplearning4j.rl4j.policy.DQNPolicy.nextAction(DQNPolicy.java:18)
   at org.deeplearning4j.rl4j.policy.Policy.play(Policy.java:72)
   at org.deeplearning4j.rl4j.policy.Policy.play(Policy.java:27)
   at me.andreaiacono.racinglearning.rl.QLearning.race(QLearning.java:81)
   at me.andreaiacono.racinglearning.core.player.QLearningPlayer.race(QLearningPlayer.java:19)
   at me.andreaiacono.racinglearning.gui.GameWorker.doInBackground(GameWorker.java:56)
   at me.andreaiacono.racinglearning.gui.GameWorker.doInBackground(GameWorker.java:11)
   at javax.swing.SwingWorker$1.call(SwingWorker.java:295)
   at java.util.concurrent.FutureTask.run(FutureTask.java:266)
   at javax.swing.SwingWorker.run(SwingWorker.java:334)
   at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
   at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
   at java.lang.Thread.run(Thread.java:748)

The input is correct: my screen is 150 x 109 pixels with 3 color channels. Why does the loaded model expect an input depth of 10? What am I missing?

Thanks, Andrea

2 Answers

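Which version are you using? If you pull from the snapshot repository, you can occasionally hit a temporary bug, although the maintainers usually fix those quickly. You may simply have picked up a snapshot build at a bad moment. Try a stable release instead.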

Answered 2019-01-16T22:51:40.937

(data input depth = 109, [minibatch,inputDepth,height,width]=[1, 109, 150, 3]; expected input depth = 10)

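It looks like you are passing 109 as inputDepth, whereas it should be 3 (the number of channels). I'm not familiar with dl4j myself, so I don't know why it says "expected input depth = 10", but you could at least try changing the order in which you supply these dimensions.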
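To illustrate the reordering suggested here, the following is a minimal plain-Java sketch, not RL4J or DL4J code. It assumes the frame arrives as [height][width][channels] = [109][150][3], which is what the shape in the error message suggests, and rearranges it to channels-first, the [inputDepth, height, width] order named in the message. `FrameLayout` and `toChannelsFirst` are hypothetical names introduced only for this example.

```java
import java.util.Arrays;

/** Hypothetical helper for illustration; not part of RL4J or DL4J. */
final class FrameLayout {

    /** Copies a frame from [height][width][channels] into [channels][height][width]. */
    static double[][][] toChannelsFirst(double[][][] hwc) {
        int height = hwc.length;
        int width = hwc[0].length;
        int channels = hwc[0][0].length;
        double[][][] chw = new double[channels][height][width];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < channels; c++) {
                    chw[c][y][x] = hwc[y][x][c];
                }
            }
        }
        return chw;
    }

    /** Returns the three dimension sizes of a rectangular 3-D array. */
    static int[] shapeOf(double[][][] a) {
        return new int[] {a.length, a[0].length, a[0][0].length};
    }

    public static void main(String[] args) {
        // Assumed layout of one screen frame: 109 rows, 150 columns, 3 color channels.
        double[][][] frame = new double[109][150][3];
        double[][][] chw = toChannelsFirst(frame);
        // Prints: [109, 150, 3] -> [3, 109, 150]
        System.out.println(Arrays.toString(shapeOf(frame)) + " -> " + Arrays.toString(shapeOf(chw)));
    }
}
```

Note that this only addresses the axis order. After reordering, the depth would be 3, which still does not match the expected input depth of 10 in the error, so the network's configured depth would need to be checked separately.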

Answered 2018-04-09T08:42:51.760