I have been trying to build a simple reinforcement learning example with tfjs. However, when I try to train the model, I get the following error:
Uncaught (in promise) Error: Error when checking target: expected dense_Dense5 to have shape [,1], but got array with shape [3,4]
I built the model as follows:
const NUM_OUTPUTS = 4;
const model = tf.sequential();
// First hidden layer, which also defines the input shape of the model
model.add(
  tf.layers.dense({
    units: LAYER_1_UNITS,
    batchInputShape: [null, NUM_INPUTS],
    activation: "relu",
  })
);
// Second hidden layer
model.add(tf.layers.dense({ units: LAYER_2_UNITS, activation: "relu" }));
// Third hidden layer
model.add(tf.layers.dense({ units: LAYER_3_UNITS, activation: "relu" }));
// Fourth hidden layer
model.add(tf.layers.dense({ units: LAYER_4_UNITS, activation: "relu" }));
// Output layer of the model
model.add(tf.layers.dense({ units: NUM_OUTPUTS, activation: "relu" }));
model.compile({
  optimizer: tf.train.adam(),
  loss: "sparseCategoricalCrossentropy",
  metrics: "accuracy",
});
Training is done by a function that computes the Q-values for a batch of samples:
batch.forEach((sample) => {
  const { state, nextState, action, reward } = sample;
  // We let the model predict the rewards of the current state.
  const current_Q: tf.Tensor = <tf.Tensor>model.predict(state);
  // We also let the model predict the rewards for the next state,
  // if there was a next state in the game.
  let future_reward = tf.zeros([NUM_ACTIONS]);
  if (nextState) {
    future_reward = <tf.Tensor>model.predict(nextState);
  }
  let totalValue =
    reward + discountFactor * future_reward.max().dataSync()[0];
  current_Q.bufferSync().set(totalValue, 0, action);
  // We can now push the state to the input collector
  x = x.concat(Array.from(state.dataSync()));
  // For the labels/outputs, we push the updated Q values
  y = y.concat(Array.from(current_Q.dataSync()));
});
await model.fit(
  tf.tensor2d(x, [batch.length, NUM_INPUTS]),
  tf.tensor2d(y, [batch.length, NUM_OUTPUTS]),
  {
    batchSize: batch.length,
    epochs: 3,
  }
);
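To make the intent of the loop above easier to follow, here is a minimal sketch of the per-sample target computation written with plain arrays instead of tensors. The helper name buildQTarget and its parameters are hypothetical and not part of the project; it only mirrors the update used above (reward plus the discounted maximum of the next state's predictions, or just the reward when there is no next state).

```typescript
// Hypothetical helper (not from the original project): builds the training
// target for one sample by overwriting the Q-value of the action taken.
function buildQTarget(
  currentQ: number[],
  action: number,
  reward: number,
  nextQ: number[] | null,
  discountFactor: number
): number[] {
  // With no next state, the future reward is treated as 0,
  // matching the tf.zeros fallback in the loop above.
  const bestFuture = nextQ === null ? 0 : Math.max(...nextQ);
  // Copy the prediction so the original array is left untouched.
  const target = currentQ.slice();
  target[action] = reward + discountFactor * bestFuture;
  return target;
}

// Example: action 2 was taken, reward 1, discount factor 0.5.
const exampleTarget = buildQTarget(
  [0.1, 0.2, 0.3, 0.4],
  2,
  1,
  [0.5, 2.0, 1.0, 0.0],
  0.5
);
console.log(exampleTarget); // [0.1, 0.2, 2, 0.4]
```

Each target keeps one value per action, so a batch of 3 samples with 4 actions gives labels of shape [3, 4], which is the shape reported in the error message.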
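This seems to be the correct way to pass the samples to the fit function, because when I log the model, the shape of the last dense layer is correct:
However, it results in the error shown above: instead of the expected shape [3,4], the check is made against the shape [,1]. I really don't understand where this shape suddenly comes from, and I would greatly appreciate some help!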
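One possible source of the [,1] shape, as far as I understand tfjs-layers: when the loss is "sparseCategoricalCrossentropy", the labels are expected to be one integer class index per sample, so the last dimension of the expected target becomes 1 regardless of how many units the output layer has. The sketch below illustrates that rule in plain TypeScript. The function expectedTargetShape is a hypothetical helper written for illustration only; it is not a tfjs API.

```typescript
type Shape = Array<number | null>;

// Hypothetical helper (not a tfjs API): illustrates how the expected target
// shape depends on the loss, under the assumption described above.
function expectedTargetShape(outputShape: Shape, loss: string): Shape {
  if (loss === "sparseCategoricalCrossentropy") {
    // Sparse labels carry a single class index per sample,
    // so the last dimension collapses to 1.
    return [...outputShape.slice(0, -1), 1];
  }
  // Other losses compare targets element-wise with the model output.
  return outputShape;
}

console.log(expectedTargetShape([null, 4], "sparseCategoricalCrossentropy")); // [null, 1]
console.log(expectedTargetShape([null, 4], "meanSquaredError")); // [null, 4]
```

If that is what happens here, a regression loss such as "meanSquaredError" would accept the [3,4] array of Q-values as targets, and it is the more common choice for Q-value regression, usually with a linear activation on the output layer.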
For a better overview, you can view or check out the whole project from its GitHub repository:
The TensorFlow code in question is located in the AI folder.
Edit:
Here is the model summary, along with some information about the shapes of the tensors x and y passed to model.fit(x, y):