0

我创建了一个新的 Keras 模型,它有两个输入和一个输出,结合了一些层,这些层管理电影中的一些标题和描述,以执行一些测试。

这是 Keras 模型:

embedding_dim = 16

# first input model
inp1 = layers.Input(shape=(250,), name='Title')
emb1 = layers.Embedding(250, embedding_dim, name='embedding_1')(inp1)
avp1 = layers.GlobalAveragePooling1D(name="global_average_pooling1d_1")(emb1)
des1 = layers.Dense(64, activation='relu', name='dense_1')(avp1)
# second input model
inp2 = layers.Input(shape=(250,), name='Description')
emb2 = layers.Embedding(250, embedding_dim, name='embedding_2')(inp2)
avp2 = layers.GlobalAveragePooling1D(name="global_average_pooling1d_2")(emb2)
des2 = layers.Dense(64, activation='relu', name='dense_2')(avp2)
# merge input models
merge = layers.Concatenate(axis=1)([des1, des2])
drp = layers.Dropout(0.2)(merge)
out = layers.Dense(1, activation='sigmoid')(drp)
model = keras.Model(inputs=[inp1, inp2], outputs=out)
# summarize layers
print(model.summary())
# plot graph
utils.plot_model(model, to_file='multiple_inputs.png')

此外,我正在尝试使用两个数据集来拟合这个模型:

model.compile(loss='mae', optimizer='adam')
model.fit(x=[train_ds_1, train_ds_2], batch_size=16, epochs=100)

但我收到以下错误:

ValueError: Failed to find data adapter that can handle input: (<class 'list'> containing values of types {"<class 'tensorflow.python.data.ops.dataset_ops.PrefetchDataset'>"}), <class 'NoneType'>

数据集是使用以下方法创建的:

 raw_train_ds = tf.keras.utils.text_dataset_from_directory(
  data_dir + "/" + "train" + "/" + feature, 
  batch_size=batch_size, 
  validation_split=0.2, 
  subset='training',
  labels="inferred",
  seed=42)

创建数据集后,有多个预处理步骤,但最后一步如下:

AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)

该对象是 tensorflow.python.data.ops.dataset_ops.PrefetchDataset 的类型

有任何想法吗?

4

1 回答 1

0

文档显示 fit() 不支持数据集列表。您可以将数据集合并为一个concatenate

连接文档中的示例:

a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
ds = a.concatenate(b)
list(ds.as_numpy_iterator())

输出应该是:[1, 2, 3, 4, 5, 6, 7]

如果您真的想在 fit() 中使用数据集,您可以使用如下问题所示的字典:这可能会对您有所帮助

于 2022-02-16T13:03:28.043 回答