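I'm building a U-Net model. The model's input shape is (None, 4, 64, 80) and the output layer's shape is (None, 7, 64, 80). X_train has shape (15993, 4, 64, 80) and Y_train has shape (15993, 64, 80). I'm fitting the model with the following code: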
# 6. Fit the model
checkpoint = ModelCheckpoint('model_train.hdf5', monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
#train_Y_one_hot = to_categorical(Y_train)
model.fit(X_train, Y_train, validation_split=0.33, epochs=10, batch_size=100, callbacks=callbacks_list, verbose=0)
But it gives me the following error:
TypeError: Bad input argument to theano function with name "train_function" at index 1 (0-based). Wrong number of dimensions: expected 4, got 3 with shape (100, 64, 80).
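The error reports that the array at index 1 of the compiled `train_function` has 3 dimensions. Keras passes the inputs first, so with a single input, index 1 is the target batch: a 100-sample slice of `Y_train`, shape (100, 64, 80). The output layer `reshape_21`, by contrast, is 4-D. One way to make the shapes agree is to one-hot encode the labels in the same channels-first layout as the output, (N, 7, 64, 80). This minimal NumPy sketch assumes `Y_train` holds integer class labels 0..6 per pixel. The helper name `to_one_hot_channels_first` is hypothetical. It avoids `to_categorical`, which never puts the class axis at position 1.

```python
import numpy as np

def to_one_hot_channels_first(labels, num_classes):
    """Hypothetical helper: integer label maps (N, H, W) -> one-hot (N, num_classes, H, W)."""
    labels = np.asarray(labels, dtype=np.int64)
    # Indexing the identity matrix with the labels gives shape (N, H, W, num_classes),
    # i.e. the class axis ends up last.
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]
    # Move the class axis to position 1 so it matches an output of (None, 7, 64, 80).
    return np.moveaxis(one_hot, -1, 1)

# Small made-up example: 2 label maps of size 64x80 with classes 0..6.
rng = np.random.default_rng(0)
Y_small = rng.integers(0, 7, size=(2, 64, 80))
Y_small_one_hot = to_one_hot_channels_first(Y_small, num_classes=7)
print(Y_small_one_hot.shape)  # (2, 7, 64, 80)
```

With the real data this would be `to_one_hot_channels_first(Y_train, 7)` passed to `model.fit` in place of `Y_train`. As float32, that array takes about 2.3 GB (15993 × 7 × 64 × 80 × 4 bytes).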
Can anyone help me fix this error? Here is the model summary:
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_25 (InputLayer)           (None, 4, 64, 80)    0
batch_normalization_12 (BatchNo (None, 4, 64, 80)    16          input_25[0][0]
conv2d_412 (Conv2D)             (None, 32, 64, 80)   1184        batch_normalization_12[0][0]
conv2d_413 (Conv2D)             (None, 32, 64, 80)   9248        conv2d_412[0][0]
dropout_101 (Dropout)           (None, 32, 64, 80)   0           conv2d_413[0][0]
max_pooling2d_97 (MaxPooling2D) (None, 32, 32, 40)   0           dropout_101[0][0]
conv2d_414 (Conv2D)             (None, 64, 32, 40)   18496       max_pooling2d_97[0][0]
conv2d_415 (Conv2D)             (None, 64, 32, 40)   36928       conv2d_414[0][0]
dropout_102 (Dropout)           (None, 64, 32, 40)   0           conv2d_415[0][0]
max_pooling2d_98 (MaxPooling2D) (None, 64, 16, 20)   0           dropout_102[0][0]
conv2d_416 (Conv2D)             (None, 128, 16, 20)  73856       max_pooling2d_98[0][0]
conv2d_417 (Conv2D)             (None, 128, 16, 20)  147584      conv2d_416[0][0]
dropout_103 (Dropout)           (None, 128, 16, 20)  0           conv2d_417[0][0]
max_pooling2d_99 (MaxPooling2D) (None, 128, 8, 10)   0           dropout_103[0][0]
conv2d_418 (Conv2D)             (None, 256, 8, 10)   295168      max_pooling2d_99[0][0]
conv2d_419 (Conv2D)             (None, 256, 8, 10)   590080      conv2d_418[0][0]
dropout_104 (Dropout)           (None, 256, 8, 10)   0           conv2d_419[0][0]
max_pooling2d_100 (MaxPooling2D (None, 256, 4, 5)    0           dropout_104[0][0]
conv2d_420 (Conv2D)             (None, 512, 4, 5)    1180160     max_pooling2d_100[0][0]
conv2d_421 (Conv2D)             (None, 512, 4, 5)    2359808     conv2d_420[0][0]
dropout_105 (Dropout)           (None, 512, 4, 5)    0           conv2d_421[0][0]
up_sampling2d_56 (UpSampling2D) (None, 512, 8, 10)   0           dropout_105[0][0]
merge_55 (Merge)                (None, 768, 8, 10)   0           up_sampling2d_56[0][0]
                                                                 dropout_104[0][0]
dropout_106 (Dropout)           (None, 768, 8, 10)   0           merge_55[0][0]
conv2d_422 (Conv2D)             (None, 256, 8, 10)   1769728     dropout_106[0][0]
conv2d_423 (Conv2D)             (None, 256, 8, 10)   590080      conv2d_422[0][0]
up_sampling2d_57 (UpSampling2D) (None, 256, 16, 20)  0           conv2d_423[0][0]
merge_56 (Merge)                (None, 384, 16, 20)  0           up_sampling2d_57[0][0]
                                                                 dropout_103[0][0]
dropout_107 (Dropout)           (None, 384, 16, 20)  0           merge_56[0][0]
conv2d_424 (Conv2D)             (None, 128, 16, 20)  442496      dropout_107[0][0]
conv2d_425 (Conv2D)             (None, 128, 16, 20)  147584      conv2d_424[0][0]
up_sampling2d_58 (UpSampling2D) (None, 128, 32, 40)  0           conv2d_425[0][0]
merge_57 (Merge)                (None, 192, 32, 40)  0           up_sampling2d_58[0][0]
                                                                 dropout_102[0][0]
dropout_108 (Dropout)           (None, 192, 32, 40)  0           merge_57[0][0]
conv2d_426 (Conv2D)             (None, 64, 32, 40)   110656      dropout_108[0][0]
conv2d_427 (Conv2D)             (None, 64, 32, 40)   36928       conv2d_426[0][0]
up_sampling2d_59 (UpSampling2D) (None, 64, 64, 80)   0           conv2d_427[0][0]
merge_58 (Merge)                (None, 96, 64, 80)   0           up_sampling2d_59[0][0]
                                                                 dropout_101[0][0]
dropout_109 (Dropout)           (None, 96, 64, 80)   0           merge_58[0][0]
conv2d_428 (Conv2D)             (None, 32, 64, 80)   27680       dropout_109[0][0]
conv2d_429 (Conv2D)             (None, 32, 64, 80)   9248        conv2d_428[0][0]
conv2d_430 (Conv2D)             (None, 7, 64, 80)    231         conv2d_429[0][0]
reshape_20 (Reshape)            (None, 7, 5120)      0           conv2d_430[0][0]
permute_20 (Permute)            (None, 5120, 7)      0           reshape_20[0][0]
activation_10 (Activation)      (None, 5120, 7)      0           permute_20[0][0]
permute_21 (Permute)            (None, 7, 5120)      0           activation_10[0][0]
reshape_21 (Reshape)            (None, 7, 64, 80)    0           permute_21[0][0]
==================================================================================================
Total params: 7,847,159
Trainable params: 7,847,151
Non-trainable params: 8
_________________________________________________________________________________________________
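The last five rows of the summary turn the 7 per-pixel scores from `conv2d_430` into class probabilities in several steps:

1. `reshape_20` flattens the 64×80 grid.
2. `permute_20` moves the class axis last.
3. `activation_10` applies the activation.
4. `permute_21` and `reshape_21` undo the first two steps.

The summary does not show which activation `activation_10` uses. This NumPy sketch assumes softmax, and the function names `softmax` and `unet_head` are hypothetical. Under that assumption, the head keeps the (N, 7, 64, 80) layout and reduces to a softmax over axis 1. That is the axis where the targets also need their 7 classes.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max along the axis for numerical stability before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def unet_head(scores):
    """Mimic reshape_20 -> permute_20 -> activation_10 (assumed softmax) -> permute_21 -> reshape_21."""
    n, c, h, w = scores.shape
    x = scores.reshape(n, c, h * w)   # Reshape: (N, 7, 64, 80) -> (N, 7, 5120)
    x = x.transpose(0, 2, 1)          # Permute: (N, 7, 5120) -> (N, 5120, 7)
    x = softmax(x, axis=-1)           # Activation over the 7 classes of each pixel
    x = x.transpose(0, 2, 1)          # Permute back: (N, 5120, 7) -> (N, 7, 5120)
    return x.reshape(n, c, h, w)      # Reshape back: (N, 7, 5120) -> (N, 7, 64, 80)

# Made-up random scores standing in for the output of conv2d_430.
scores = np.random.default_rng(0).normal(size=(2, 7, 64, 80))
probs = unet_head(scores)
print(probs.shape)                                  # (2, 7, 64, 80)
print(np.allclose(probs, softmax(scores, axis=1)))  # True
```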