我正在训练一个网络去噪图像,为此我使用 CIFAR10 数据集。我正在尝试生成一个自定义损失函数,以便损失是 mse/classification_accuracy。鉴于我的网络接收 32x32(噪声)图像作为输入并预测 32x32(去噪)图像,我假设 y_pred 和 Y_true 将是 32x32 图像的数组。因此我的自定义损失函数如下所示:
def custom_loss():
def joint_optimized_loss(y_true, y_pred):
mse = K.mean(K.square(y_pred - y_true), axis=-1)
preds = classif_model.predict(y_pred)
correctPreds = 0
totPreds = 0
for pred in preds:
predictedClass = pred.index(max(pred))
totPreds += 1
if predictedClass == currentClass:
correctPreds += 1
classifAccuracy = correctPreds / totPreds
loss = mse / classifAccuracy
return loss
return joint_optimized_loss
myModel.compile(optimizer='adadelta', loss=custom_loss())
classif_model 是一个预训练模型,可将 CIFAR10 图像分类为 10 个类别之一。它接收一组 32x32 图像。
但是,当我运行我的代码时,出现以下错误:
回溯(最近一次通话最后):
文件“myCode.py”,第 94 行,在
myModel.compile(optimizer='adadelta', loss=custom_loss()) 文件“/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py”,第 850 行, 在编译
sample_weight,掩码)文件“/home/rvdalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py”,第450行,加权
score_array = fn(y_true, y_pred) 文件“myCode.py”,第 57 行,在joint_optimized_loss 中
preds = classif_model.predict(y_pred) 文件“/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/models.py”,第 913 行,在预测中
return self.model.predict(x, batch_size=batch_size, verbose=verbose) 文件“/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py”,行1713,在预测中
详细=详细,步骤=步骤)文件“/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py”,第 1260 行,在 _predict_loop
batches = _make_batches(num_samples, batch_size) 文件“/home/rvidalma/anaconda2/envs/tensorUpdated/lib/python2.7/site-packages/keras/engine/training.py”,第 374 行,在 _make_batches
num_batches = int(np.ceil(size / float(batch_size)))
AttributeError: 'Dimension' object has no attribute 'ceil'
我认为这与以下事实有关,y_true
并且y_pred
都是在训练之前为空的张量,因此classif_model.predict
失败,因为它期望一个数组。但是我不确定如何解决这个问题......
我尝试获取y_pred
using的值K.get_value(y_pred)
,但这给了我以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError:形状 [-1,32,32,3] 具有负尺寸 [[Node: input_1 = Placeholderdtype=DT_FLOAT, shape=[?,32,32,3], _device="/工作:本地主机/副本:0/任务:0/cpu:0“]]