我正在创建一个 tensorflow 模型,该模型应该处理包含图像文件和其他分类输入的输入,两者都可以通过 csv 文件访问。最初我的模型只有图像输入,可以正常训练;但是当我用额外的元数据扩展输入后,我的 model.fit 调用就报错了。
这是我的代码:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers
from keras.callbacks import History
df = pd.read_csv('dataset2.csv')
file_paths = df['file_name'].values
labels = df['label'].to_numpy(dtype='float32')
metadata = df[['usage', 'completion', 'height', 'construction']].to_numpy(dtype='float32')
dataset_size = len(df.index)
train_size = int(0.8 * dataset_size)
val_size = int(0.1 * dataset_size)
test_size = int(0.1 * dataset_size)
full_dataset_img = tf.data.Dataset.from_tensor_slices(file_paths)
ds_train_img = full_dataset_img.take(train_size)
ds_test_img = full_dataset_img.skip(train_size)
ds_val_img = ds_test_img.skip(val_size)
ds_test_img = ds_test_img.take(test_size)
labels = np.reshape(labels, (dataset_size, 1))
full_dataset_lbl = tf.data.Dataset.from_tensor_slices(labels)
ds_train_lbl = full_dataset_lbl.take(train_size)
ds_test_lbl = full_dataset_lbl.skip(train_size)
ds_val_lbl = ds_test_lbl.skip(val_size)
ds_test_lbl = ds_test_lbl.take(test_size)
metadata = np.reshape(metadata, (dataset_size, 4))
full_dataset_dat = tf.data.Dataset.from_tensor_slices(metadata)
ds_train_dat = full_dataset_dat.take(train_size)
ds_test_dat = full_dataset_dat.skip(train_size)
ds_val_dat = ds_test_dat.skip(val_size)
ds_test_dat = ds_test_dat.take(test_size)
# FUNCTION TO READ AND NORMALIZE THE IMAGES AND METADATA
def read_image(image_file):
image = tf.io.read_file(image_file)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, (350, 350))
return tf.cast(image, tf.float32) / 255.0
# FUNCTION FOR DATA AUGMENTATION
def augment(image):
if tf.random.uniform((), minval=0, maxval=1) < 0.1:
image = tf.tile(tf.image.rgb_to_grayscale(image), [1, 1, 3])
image = tf.image.random_brightness(image, max_delta=0.25)
image = tf.image.random_contrast(image, lower=0.75, upper=1.25)
image = tf.image.random_saturation(image, lower=0.75, upper=1.25)
image = tf.image.random_flip_left_right(image)
return image
ds_train_img = ds_train_img.map(read_image)
ds_train_img = ds_train_img.map(augment)
ds_val_img = ds_val_img.map(read_image)
ds_test_img = ds_test_img.map(read_image)
ds_train_lbl = np.array(ds_train_lbl)
ds_train_dat = np.array(ds_train_dat)
# DEFINING FUNCTIONAL MODEL FOR COMBINING IMAGE AND METADATA INPUTS
input_img = keras.Input(shape=(350, 350, 3))
input_dat = keras.Input(shape=(4,))
x = layers.Conv2D(16, (3, 3), activation='relu', kernel_regularizer=regularizers.l2(0.02), padding='same')(input_img)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.MaxPooling2D()(x)
x = layers.Conv2D(32, (3, 3), activation='relu', kernel_regularizer=regularizers.l2(0.02), padding='same')(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.MaxPooling2D()(x)
x = layers.Conv2D(64, (3, 3), activation='relu', kernel_regularizer=regularizers.l2(0.02), padding='same')(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.MaxPooling2D()(x)
x = layers.Conv2D(128, (3, 3), activation='relu', kernel_regularizer=regularizers.l2(0.02), padding='same')(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.MaxPooling2D()(x)
x = layers.Flatten()(x)
x = layers.Flatten()(x)
data = layers.Dense(128, activation='relu')(input_dat)
multi = layers.concatenate([x, data])
multi = layers.Dense(256, activation='relu')(multi)
multi = layers.Dropout(0.35)(multi)
output = layers.Dense(8, activation='sigmoid')(multi)
model = keras.Model(inputs=[input_img, input_dat], outputs=output)
model.compile(
optimizer=keras.optimizers.Adam(3e-5),
loss=[keras.losses.SparseCategoricalCrossentropy()],
metrics=["accuracy"])
model.fit(x=[ds_train_img, ds_train_dat], y=ds_train_lbl,
validation_data=([ds_val_img, ds_val_dat], ds_val_lbl),
epochs=30, batch_size=16, verbose=1)
当我想运行代码时,在 model.fit 中出现以下错误:
ValueError: Failed to find data adapter that can handle input: (<class 'list'> containing values of types {"<class 'tensorflow.python.data.ops.dataset_ops.MapDataset'>", "<class 'numpy.ndarray'>"}), <class 'numpy.ndarray'>
非常感谢任何帮助!