我正在尝试在 emnist 数据集上训练神经网络,但是当我尝试展平图像时,它会引发以下错误:
WARNING:tensorflow:Model 是用形状 (None, 28, 28) 构造的输入 Tensor("flatten_input:0", shape=(None, 28, 28), dtype=float32),但它是在不兼容的输入上调用的形状(无、1、28、28)。
我无法弄清楚似乎是什么问题,并尝试更改我的预处理,从我的 model.fit 和我的 ds.map 中删除批量大小。
这是完整的代码:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
def preprocess(dict):
image = dict['image']
image = tf.transpose(image)
label = dict['label']
return image, label
train_data, validation_data = tfds.load('emnist/letters', split = ['train', 'test'])
train_data_gen = train_data.map(preprocess).shuffle(1000).batch(32)
validation_data_gen = validation_data.map(preprocess).batch(32)
print(train_data_gen)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape = (28, 28)),
tf.keras.layers.Dense(128, activation = 'relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation = 'softmax')
])
model.compile(optimizer = 'adam',
loss = 'sparse_categorical_crossentropy',
metrics = ['accuracy'])
early_stopping = keras.callbacks.EarlyStopping(monitor = 'val_accuracy', patience = 10)
history = model.fit(train_data_gen, epochs = 50, batch_size = 32, validation_data = validation_data_gen, callbacks = [early_stopping], verbose = 1)
model.save('emnistmodel.h5')