我想使用 CNN 来解决去模糊任务,我有训练数据,它是一个 png 图像的目录和一个包含文件名的相应文本文件。
由于数据太大而无法一步添加到内存中,是否有任何 API 或某种方法可以使我可以将模糊图像作为输入读取,并将其作为预期结果进行训练?
我花了很多时间来解决这个问题,但是在阅读了在线 API 介绍中的 API 后,我感到困惑。
我想使用 CNN 来解决去模糊任务,我有训练数据,它是一个 png 图像的目录和一个包含文件名的相应文本文件。
由于数据太大而无法一步添加到内存中,是否有任何 API 或某种方法可以使我可以将模糊图像作为输入读取,并将其作为预期结果进行训练?
我花了很多时间来解决这个问题,但是在阅读了在线 API 介绍中的 API 后,我感到困惑。
该方法并不那么混乱。tensorflow 提供了 TFrecords 文件以充分利用内存。
def create_cord():
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index in xrange(66742):
blur_file_name = get_file_name(index, True)
orig_file_name = get_file_name(index, False)
blur_image_path = cwd + blur_file_name
orig_image_path = cwd + orig_file_name
blur_image = Image.open(blur_image_path)
orig_image = Image.open(orig_image_path)
blur_image = blur_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))
orig_image = orig_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))
blur_image_raw = blur_image.tobytes()
orig_image_raw = orig_image.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
"blur_image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[blur_image_raw])),
'orig_image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[orig_image_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
读取数据集:
def read_and_decode(filename):
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'blur_image_raw': tf.FixedLenFeature([], tf.string),
'orig_image_raw': tf.FixedLenFeature([], tf.string),
})
blur_img = tf.decode_raw(features['blur_image_raw'], tf.uint8)
blur_img = tf.reshape(blur_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
blur_img = tf.cast(blur_img, tf.float32) * (1. / 255) - 0.5
orig_img = tf.decode_raw(features['blur_image_raw'], tf.uint8)
orig_img = tf.reshape(orig_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
orig_img = tf.cast(orig_img, tf.float32) * (1. / 255) - 0.5
return blur_img, orig_img
if __name__ == '__main__':
# create_cord()
blur, orig = read_and_decode("train.tfrecords")
blur_batch, orig_batch = tf.train.shuffle_batch([blur, orig],
batch_size=3, capacity=1000,
min_after_dequeue=100)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
# 启动队列
threads = tf.train.start_queue_runners(sess=sess)
for i in range(3):
v, l = sess.run([blur_batch, orig_batch])
print(v.shape, l.shape)