0

我想将一个 numpy 数组输入 CNN,其中包含 2 个国际象棋位置,一个在移动之前,第二个在某个移动之后。我想训练 CNN 来估计传统国际象棋程序对这一步的评估。这些评估是 int 值。

x和的形状y是:x: (2000000, 8, 8, 2) , y: (2000000,)

型号代码:

#define model
model = Sequential()
#model.add(Dense(1024, activation='relu', input_dim=864))

model.add(Conv2D(128, kernel_size=(3, 3), strides=(1, 1), activation='relu', input_shape=(8,8,2)))
model.add(Conv2D(128, kernel_size=(3, 3), strides=(1, 1), activation='relu'))

model.add(Dense(128, activation='relu', init='uniform'))
model.add(BatchNormalization())

model.add(Dense(1))

model.compile(loss='mean_squared_error', optimizer='adam',metrics=['mae'])
print(model.summary())

培训完成:

history = model.fit(x, y, validation_split=0.1, epochs=5, batch_size=20000, verbose=2)

它给了我以下错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-7-619de3f1be1b> in <module>()
    171         for i in range(5):
    172             print("Fitting begins", x.shape, y.shape)
--> 173             history = model.fit(x, y, validation_split=0.1, epochs=5, batch_size=20000, verbose=2)
    174             #score = model.evaluate(x, y, verbose=2)
    175             #print(score)

/usr/local/lib/python3.6/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
    950             sample_weight=sample_weight,
    951             class_weight=class_weight,
--> 952             batch_size=batch_size)
    953         # Prepare validation data.
    954         do_validation = False

/usr/local/lib/python3.6/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    787                 feed_output_shapes,
    788                 check_batch_axis=False,  # Don't enforce the batch size.
--> 789                 exception_prefix='target')
    790 
    791             # Generate sample-wise weight values given the `sample_weight` and

/usr/local/lib/python3.6/site-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    126                         ': expected ' + names[i] + ' to have ' +
    127                         str(len(shape)) + ' dimensions, but got array '
--> 128                         'with shape ' + str(data_shape))
    129                 if not check_batch_axis:
    130                     data_shape = data_shape[1:]

ValueError: Error when checking target: expected dense_10 to have 4 dimensions, but got array with shape (2000000, 1)

我做错了什么?我怎样才能解决这个问题?


好的,我意识到问题与最后一层的输出形状有关:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_13 (Conv2D)           (None, 6, 6, 128)         2432      
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 4, 4, 128)         147584    
_________________________________________________________________
dense_14 (Dense)             (None, 4, 4, 128)         16512     
_________________________________________________________________
batch_normalization_7 (Batch (None, 4, 4, 128)         512       
_________________________________________________________________
dense_15 (Dense)             (None, 4, 4, 1)           129       
=================================================================

但这是为什么呢(None, 4, 4, 1)?不应该(None, 1)吗?它是一个值为 1 的单个神经元!

4

1 回答 1

0

但为什么是(无、4、4、1)?不应该是(无,1)???

不,不应该。因为Dense 层应用在其 input 的最后一个轴上,因此在这种情况下,它应用在Conv2Dlayer 的输出上,即 4D 张量,所以该Dense层的输出也将是 4D 张量。要解决此问题,您可以先Conv2D使用Flatten图层展平图层的输出,然后再使用该Dense图层,如下所示:

model.add(Conv2D(128, kernel_size=(3, 3), strides=(1, 1), activation='relu'))
model.add(Flatten())  # flatten the output of `Conv2D` to a 2D tensor
model.add(Dense(128, activation='relu', init='uniform'))
于 2018-11-28T11:49:50.247 回答