我最近正在做一个基于 tensorflow CNN 的项目,MNIST 数据集带有服务器接口。
在预测部分,我使用tf.argmax()来获得最大的 logit,这将是预测值。但是,它返回的值似乎不是正确的答案。
predict 函数大概是这样的:
self.img = tf.reshape(tf.image.convert_image_dtype(img, tf.float32), shape=[1, 28, 28, 1])
self._create_model()
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('../checkpoints/')
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
pred = tf.nn.softmax(self.logits)
prediction = tf.argmax(pred, 1)
logit = sess.run(pred)
result = sess.run(prediction)[0]
print(logit)
print(result)
return result
结果是:
127.0.0.1 - - [19/Apr/2018 21:35:47] "POST /index.html HTTP/1.1" 200 -
[[ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]
1
如您所见,logits 显示最大编号的索引是5,但是tf.argmax()给了我1。
顺便说一句,我的模型是基本的 MNIST CNN 模型,您可以在链接中看到。
那么这个tf.argmax()函数发生了什么,或者我的代码有问题?