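I am building a U-Net model. The model's input shape is (None, 4, 64, 80) and the output layer's shape is (None, 7, 64, 80). X_train has shape (15993, 4, 64, 80) and Y_train has shape (15993, 64, 80). I am using the following code to fit the model: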

# 6. Fit the model
from keras.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint('model_train.hdf5', monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
# train_Y_one_hot = to_categorical(Y_train)
model.fit(X_train, Y_train, validation_split=0.33, epochs=10, batch_size=100, callbacks=callbacks_list, verbose=0)

But it gives me the following error:

TypeError: Bad input argument to theano function with name "train_function" at index 1 (0-based). Wrong number of dimensions: expected 4, got 3 with shape (100, 64, 80).

Can anyone help me fix this error? Below is the model summary.

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_25 (InputLayer)           (None, 4, 64, 80)    0                                            
batch_normalization_12 (BatchNo (None, 4, 64, 80)    16          input_25[0][0]                   
conv2d_412 (Conv2D)             (None, 32, 64, 80)   1184        batch_normalization_12[0][0]     
conv2d_413 (Conv2D)             (None, 32, 64, 80)   9248        conv2d_412[0][0]                 
dropout_101 (Dropout)           (None, 32, 64, 80)   0           conv2d_413[0][0]                 
max_pooling2d_97 (MaxPooling2D) (None, 32, 32, 40)   0           dropout_101[0][0]                
conv2d_414 (Conv2D)             (None, 64, 32, 40)   18496       max_pooling2d_97[0][0]           
conv2d_415 (Conv2D)             (None, 64, 32, 40)   36928       conv2d_414[0][0]                 
dropout_102 (Dropout)           (None, 64, 32, 40)   0           conv2d_415[0][0]                 
max_pooling2d_98 (MaxPooling2D) (None, 64, 16, 20)   0           dropout_102[0][0]                
conv2d_416 (Conv2D)             (None, 128, 16, 20)  73856       max_pooling2d_98[0][0]           
conv2d_417 (Conv2D)             (None, 128, 16, 20)  147584      conv2d_416[0][0]                 
dropout_103 (Dropout)           (None, 128, 16, 20)  0           conv2d_417[0][0]                 
max_pooling2d_99 (MaxPooling2D) (None, 128, 8, 10)   0           dropout_103[0][0]                
conv2d_418 (Conv2D)             (None, 256, 8, 10)   295168      max_pooling2d_99[0][0]           
conv2d_419 (Conv2D)             (None, 256, 8, 10)   590080      conv2d_418[0][0]                 
dropout_104 (Dropout)           (None, 256, 8, 10)   0           conv2d_419[0][0]                 
max_pooling2d_100 (MaxPooling2D (None, 256, 4, 5)    0           dropout_104[0][0]                
conv2d_420 (Conv2D)             (None, 512, 4, 5)    1180160     max_pooling2d_100[0][0]          
conv2d_421 (Conv2D)             (None, 512, 4, 5)    2359808     conv2d_420[0][0]                 
dropout_105 (Dropout)           (None, 512, 4, 5)    0           conv2d_421[0][0]                 
up_sampling2d_56 (UpSampling2D) (None, 512, 8, 10)   0           dropout_105[0][0]                
merge_55 (Merge)                (None, 768, 8, 10)   0           up_sampling2d_56[0][0]           
                                                                 dropout_104[0][0]                
dropout_106 (Dropout)           (None, 768, 8, 10)   0           merge_55[0][0]                   
conv2d_422 (Conv2D)             (None, 256, 8, 10)   1769728     dropout_106[0][0]                
conv2d_423 (Conv2D)             (None, 256, 8, 10)   590080      conv2d_422[0][0]                 
up_sampling2d_57 (UpSampling2D) (None, 256, 16, 20)  0           conv2d_423[0][0]                 
merge_56 (Merge)                (None, 384, 16, 20)  0           up_sampling2d_57[0][0]           
                                                                 dropout_103[0][0]                
dropout_107 (Dropout)           (None, 384, 16, 20)  0           merge_56[0][0]                   
conv2d_424 (Conv2D)             (None, 128, 16, 20)  442496      dropout_107[0][0]                
conv2d_425 (Conv2D)             (None, 128, 16, 20)  147584      conv2d_424[0][0]                 
up_sampling2d_58 (UpSampling2D) (None, 128, 32, 40)  0           conv2d_425[0][0]                 
merge_57 (Merge)                (None, 192, 32, 40)  0           up_sampling2d_58[0][0]           
                                                                 dropout_102[0][0]                
dropout_108 (Dropout)           (None, 192, 32, 40)  0           merge_57[0][0]                   
conv2d_426 (Conv2D)             (None, 64, 32, 40)   110656      dropout_108[0][0]                
conv2d_427 (Conv2D)             (None, 64, 32, 40)   36928       conv2d_426[0][0]                 
up_sampling2d_59 (UpSampling2D) (None, 64, 64, 80)   0           conv2d_427[0][0]                 
merge_58 (Merge)                (None, 96, 64, 80)   0           up_sampling2d_59[0][0]           
                                                                 dropout_101[0][0]                
dropout_109 (Dropout)           (None, 96, 64, 80)   0           merge_58[0][0]                   
conv2d_428 (Conv2D)             (None, 32, 64, 80)   27680       dropout_109[0][0]                
conv2d_429 (Conv2D)             (None, 32, 64, 80)   9248        conv2d_428[0][0]                 
conv2d_430 (Conv2D)             (None, 7, 64, 80)    231         conv2d_429[0][0]                 
reshape_20 (Reshape)            (None, 7, 5120)      0           conv2d_430[0][0]                 
permute_20 (Permute)            (None, 5120, 7)      0           reshape_20[0][0]                 
activation_10 (Activation)      (None, 5120, 7)      0           permute_20[0][0]                 
permute_21 (Permute)            (None, 7, 5120)      0           activation_10[0][0]              
reshape_21 (Reshape)            (None, 7, 64, 80)    0           permute_21[0][0]                 
==================================================================================================
Total params: 7,847,159
Trainable params: 7,847,151
Non-trainable params: 8
_________________________________________________________________________________________________
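The last five layers (Reshape, Permute, Activation, Permute, Reshape) appear to implement a per-pixel classification over the 7 output channels. The NumPy sketch below mimics that tail on a (batch, 7, 64, 80) array, assuming activation_10 is a softmax (the summary does not show which activation it uses), and checks that it matches a softmax taken directly over the channel axis. The helper names are hypothetical and not part of the original model code.

```python
import numpy as np

def softmax(x, axis):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def reshape_permute_softmax(x):
    # Mimics the summary's tail (assumed softmax activation):
    # Reshape -> (batch, 7, 5120), Permute -> (batch, 5120, 7),
    # softmax on the last axis, then Permute and Reshape back.
    b, c, h, w = x.shape
    y = x.reshape(b, c, h * w).transpose(0, 2, 1)
    y = softmax(y, axis=-1)
    return y.transpose(0, 2, 1).reshape(b, c, h, w)

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 7, 64, 80))  # random stand-in for conv2d_430 output
out = reshape_permute_softmax(logits)

# Same result as a softmax over axis 1 (the 7 class channels).
print(np.allclose(out, softmax(logits, axis=1)))  # True
# Each pixel's 7 class probabilities sum to 1.
print(np.allclose(out.sum(axis=1), 1.0))  # True
```

Because the output is a 7-channel probability map per pixel, the targets passed to `fit` would need the same (N, 7, 64, 80) layout.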

1 Answer

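It seems the arrays passed to `fit` need four dimensions, but as far as I can tell the one at index 1 (the targets, Y_Train) only has three, right? (100, 64, 80).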

If that doesn't work, you could try reshaping to (1, 100, 64, 80). Please try to include more of your code.
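One way to give the targets the fourth dimension the model expects is to one-hot encode them into 7 channels, which is what the commented-out `to_categorical` line in the question hints at. The sketch below assumes Y_train holds integer class IDs 0 to 6 per pixel and that the model is channels_first (as the summary suggests); it uses plain NumPy rather than `to_categorical` so the class axis can be placed second. `one_hot_channels_first` is a hypothetical helper, and the labels are random stand-ins.

```python
import numpy as np

NUM_CLASSES = 7  # matches the 7 output channels in the model summary

def one_hot_channels_first(labels, num_classes=NUM_CLASSES):
    # labels: integer array (N, H, W) with values in [0, num_classes).
    # Returns a float32 array (N, num_classes, H, W) for a channels_first output.
    labels = np.asarray(labels, dtype=np.int64)
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]  # (N, H, W, C)
    return np.moveaxis(one_hot, -1, 1)  # (N, C, H, W)

# Small random stand-in for Y_train (15993, 64, 80), for illustration only.
rng = np.random.default_rng(0)
y_small = rng.integers(0, NUM_CLASSES, size=(5, 64, 80))
y_one_hot = one_hot_channels_first(y_small)
print(y_one_hot.shape)  # (5, 7, 64, 80)
```

With the real data this would be passed as `model.fit(X_train, one_hot_channels_first(Y_train), ...)`. Note that a float32 one-hot copy of all 15993 samples takes roughly 2.3 GB, so a generator that encodes one batch at a time may be preferable.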

Answered 2018-08-07T06:32:07.763