我使用卷积自编码神经网络方法训练我的模型然后保存,但是当我恢复我的模型以重建与训练图像相似的图像时,重建结果非常糟糕,损失很大。我不确定保存和读取文件是否有问题。
训练模型并保存!
#--------------------------------------------------------------------------
x = tf.placeholder(tf.float32, [None, dim], name = "X")
y = tf.placeholder(tf.float32, [None, dim], name = "Y")
keepprob = tf.placeholder(tf.float32, name = "K")
pred = cae(x, weights, biases, keepprob, imgsize)["out"]
cost = tf.reduce_sum(tf.square(cae(x, weights, biases, keepprob,imgsize)["out"] - tf.reshape(y, shape=[-1, imgsize, imgsize, 1])))
learning_rate = 0.01
optm = tf.train.AdamOptimizer(learning_rate).minimize(cost)
#--------------------------------------------------------------------------
sess = tf.Session()
save_model = os.path.join(PATH,'temp_saved_model')
saver = tf.train.Saver()
tf.add_to_collection("COST", cost)
tf.add_to_collection("PRED", pred)
sess.run(tf.global_variables_initializer())
mean_img = np.zeros((dim))
batch_size = 100
n_epochs = 1000
for epoch_i in range(n_epochs):
for batch_i in range(ntrain // batch_size):
trainbatch = np.array(train)
trainbatch = np.array([img - mean_img for img in trainbatch])
sess.run(optm, feed_dict={x: trainbatch, y: trainbatch, keepprob: 1.})
save_path = saver.save(sess, save_model)
print('Model saved in file: %s' %save_path)
sess.close()
恢复模型并尝试重建图像。
tf.reset_default_graph()
save_model = os.path.join(PATH + 'SaveModel/','temp_saved_model.meta')
imgsize = 64
dim = imgsize * imgsize
mean_img = np.zeros((dim))
with tf.Session() as sess:
saver = tf.train.import_meta_graph(save_model)
saver.restore(sess, tf.train.latest_checkpoint(PATH + 'SaveModel/'))
cost = tf.get_collection("COST")[0]
pred = tf.get_collection("PRED")[0]
graph = tf.get_default_graph()
x = graph.get_tensor_by_name("X:0")
y = graph.get_tensor_by_name("Y:0")
k = graph.get_tensor_by_name("K:0")
for i in range(10):
test_xs = np.array(data)
test = load_image(test_xs, imgsize)
test = np.array([img - mean_img for img in test])
print ("[%02d/%02d] cost: %.4f" % (i, 10, sess.run(cost, feed_dict={x: test, y: test, K: 1.})))
训练过程中的损失值是1.321...,但是重建损失是16545.10441...我的代码有问题吗?