I am using TensorFlow v2.4.1 on Ubuntu 16.04 with the CIFAR-10 dataset.

My end goal is to prune the model, but I get an error related to the shapes of the training data and training labels. The fit() call raises:

InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [128,10] and labels shape [1280]

The code that produces the error is:

model_for_pruning.fit(datasample.train_data, datasample.train_labels,
                    batch_size=batch_size,
                    epochs=epochs,
                    #validation_split=validation_split,
                    validation_data=(datasample.validation_data,
                                     datasample.validation_labels),
                    callbacks=callbacks)

The dataset has the following shapes:

Training Data: datasample.train_data numpy.ndarray (45000, 32, 32, 3)
Training Labels: datasample.train_labels numpy.ndarray (45000, 10)
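
A hedged note on where the number 1280 may come from: 1280 = 128 × 10, which is the size of one batch of one-hot labels once it is flattened. This message is typically raised by sparse cross-entropy losses, which expect one integer class id per example. One possible cause (an assumption, since the compile call for model_for_pruning is not shown) is that the pruned model was compiled with a sparse loss while the labels are one-hot. The NumPy sketch below uses made-up labels, not the real dataset, to show the shape arithmetic.

```python
import numpy as np

batch_size, num_classes = 128, 10

# Hypothetical one-hot label batch laid out like datasample.train_labels.
rng = np.random.default_rng(0)
class_ids = rng.integers(0, num_classes, size=batch_size)
one_hot = np.eye(num_classes, dtype=np.float32)[class_ids]  # shape (128, 10)

# A sparse loss expects one integer per example. Flattening one-hot labels
# gives batch_size * num_classes entries instead.
flattened = one_hot.reshape(-1)
print(one_hot.shape, flattened.shape)

# Recover one integer class id per example from the one-hot rows.
sparse_labels = one_hot.argmax(axis=1)
print(sparse_labels.shape)
```

If that assumption holds, either compiling model_for_pruning with the same loss as the original model or passing integer labels such as sparse_labels would make the first dimensions agree.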

I believe the error is related to the input shapes, but I don't know how to fix it. In fact, this is how I trained my original model:

model.fit(data.train_data, data.train_labels,
          batch_size=batch_size,
          validation_data=(data.validation_data, data.validation_labels),
          epochs=num_epochs,
          shuffle=True,
          callbacks=[tensorboard_callback])

Model: '''

model = Sequential()  # assumed: the model construction is not shown in the post
model.add(Conv2D(params[0], (3, 3), input_shape=data.train_data.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(params[1], (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(params[2], (3, 3)))
model.add(Activation('relu'))
model.add(Conv2D(params[3], (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(params[4]))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(params[5]))
model.add(Activation('relu'))
model.add(Dense(10))

'''

Then: '''

def fn(correct, predicted):
    # Softmax cross-entropy on logits scaled by the temperature train_temp
    return tf.nn.softmax_cross_entropy_with_logits(labels=correct,
                                                   logits=predicted/train_temp)

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

model.compile(loss=fn,
              optimizer=sgd,
              metrics=['accuracy'])

model.fit(data.train_data, data.train_labels,
          batch_size=batch_size,
          validation_data=(data.validation_data, data.validation_labels),
          epochs=num_epochs,
          shuffle=True,
          callbacks=[tensorboard_callback])

'''
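
For reference, here is a minimal NumPy sketch of what the custom loss fn above is expected to compute, assuming one-hot labels and a scalar temperature. The function name and the sample values are hypothetical; the real code relies on tf.nn.softmax_cross_entropy_with_logits.

```python
import numpy as np

def softmax_cross_entropy_with_temperature(labels, logits, temperature=1.0):
    """Per-example cross-entropy between one-hot labels and softmax(logits / temperature)."""
    scaled = logits / temperature
    # Subtract the row maximum for numerical stability.
    shifted = scaled - scaled.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # One loss value per example, so the output has shape (batch,).
    return -(labels * log_probs).sum(axis=1)

# Made-up batch of two examples with three classes.
labels = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
logits = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])

loss_t1 = softmax_cross_entropy_with_temperature(labels, logits, temperature=1.0)
loss_t2 = softmax_cross_entropy_with_temperature(labels, logits, temperature=2.0)
print(loss_t1.shape, loss_t2.shape)
```

The labels and logits must share both dimensions here, (batch, classes), which is the layout of the (45000, 10) one-hot labels described in the question.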
