1

我正在尝试使用带有 tf.keras 的 fit() 方法训练我的模型,因为输入数据来自hdf5文件,所以我将参数shuffle='batch'传递给fit()方法。但在第一个 epoch 结束后,出现以下错误:

TypeError: TypeError while preparing batch. If using HDF5 input data, pass shuffle="batch".

这是我的 fit() 方法:

model.fit(
    x=features_train,
    y=topics_train,
    batch_size=16384,
    epochs=35,
    callbacks=create_callbacks(),
    validation_data=(features_val, topics_val),
    shuffle='batch'
)

变量features_trainfeatures_val取自 hdf5 文件。

4

1 回答 1

1

features_val通过将其转换为 numpy 数组来解决它。

features_val_arr = np.array(features_val)

model.fit(
    x=features_train,
    y=topics_train,
    batch_size=16384,
    epochs=35,
    callbacks=create_callbacks(),
    validation_data=(features_val_arr, topics_val),
    shuffle='batch'
)
于 2019-04-14T08:11:56.233 回答