我已经尝试了两种用于转移学习的 tensorflow 模型。模型是
- inception_v3_2016_08_28.tar.gz - 来自tensorflow-models
- classify_image_graph_def.pb - 附带 tensorflow image_retraining 代码。
但我得到的结果完全不同。第二个模型比第一个模型表现得更好。是预期的吗?第一个模型的准确率为 57%,而第二个模型的准确率为 80%。
第一个模型是检查点文件。对于迁移学习,我已将检查点文件转换为 protobuf 文件。然后使用 tensorflow 附带的 python 代码 retrain.py 进行再训练。以下代码用于将检查点文件转换为 protobuf 文件。
checkpoint_file = '../check_points/inception_v3.ckpt'
decode_jpeg_data = tf.placeholder(tf.string)
decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3, dct_method="INTEGER_ACCURATE")
if decode_jpeg.dtype != tf.float32:
decode_jpeg = tf.image.convert_image_dtype(decode_jpeg, dtype=tf.float32)
image_ = tf.expand_dims(decode_jpeg, 0)
image = tf.image.resize_bicubic(image_, [299, 299], align_corners=True)
scaled_input_tensor = tf.scalar_mul((1.0/255), image)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)
# loading the inception graph
arg_scope = inception_v3_arg_scope()
with slim.arg_scope(arg_scope):
logits, end_points = inception_v3(inputs=scaled_input_tensor, is_training=False, num_classes=1001)
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, checkpoint_file)
with gfile.FastGFile('./models/inceptionv3.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString())