
I want to use a CNN for a deblurring task. I have the training data: a directory of PNG images and a corresponding text file that lists the file names.

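The data is too large to load into memory in one go. Is there an API, or some other approach, that lets me read the blurred images as the input and the corresponding original (sharp) images as the expected output during training?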

I've spent a lot of time on this, but reading the API descriptions online has only left me confused.


1 Answer


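The approach is less confusing than it looks. TensorFlow provides the TFRecord file format for exactly this case: convert the image pairs into a single TFRecord file once, then during training stream records from that file through input queues, so only the examples currently buffered or batched have to fit in memory.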
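To make "stream records from the file" concrete: a TFRecord file is just a sequence of length-prefixed, checksummed records, so a reader can pull out one serialized `tf.train.Example` at a time without loading the rest. The following is a standard-library-only sketch of that framing (uint64 little-endian length, masked CRC-32C of the length, the payload, masked CRC-32C of the payload), written purely for illustration. The helper names `crc32c`, `masked_crc`, `write_record` and `read_records` are hypothetical; in practice TensorFlow's own writer and reader handle the format.

```python
# Illustrative, standard-library-only sketch of TFRecord framing.
# Real code should use TensorFlow's TFRecord writer/reader instead.
import io
import struct

# CRC-32C (Castagnoli) lookup table, reflected polynomial 0x82F63B78.
_CRC32C_TABLE = []
for _i in range(256):
    _crc = _i
    for _ in range(8):
        _crc = (_crc >> 1) ^ 0x82F63B78 if _crc & 1 else _crc >> 1
    _CRC32C_TABLE.append(_crc)


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    # TFRecord stores a "masked" CRC: rotate right by 15 bits, then add a constant.
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(stream, payload: bytes) -> None:
    length = struct.pack("<Q", len(payload))               # uint64 length
    stream.write(length)
    stream.write(struct.pack("<I", masked_crc(length)))    # checksum of the length
    stream.write(payload)                                  # serialized Example
    stream.write(struct.pack("<I", masked_crc(payload)))   # checksum of the payload


def read_records(stream):
    """Yield payloads one at a time; only the current record is held in memory."""
    while True:
        header = stream.read(8)
        if not header:
            return
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length field")
        payload = stream.read(length)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload


# Usage: round-trip two records through an in-memory buffer.
buf = io.BytesIO()
write_record(buf, b"first example")
write_record(buf, b"second example")
buf.seek(0)
print(list(read_records(buf)))  # [b'first example', b'second example']
```

Each record costs 16 bytes of framing overhead, and the checksums let a reader detect a truncated or corrupted file instead of silently feeding garbage into training.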

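import tensorflow as tf  # TensorFlow 1.x API
from PIL import Image

# cwd, get_file_name(index, is_blur), IMAGE_WIDTH and IMAGE_HEIGH are defined
# elsewhere in the asker's script; get_file_name presumably returns the
# index-th blurred (True) or original (False) file name from the name list.


def create_cord():
    writer = tf.python_io.TFRecordWriter("train.tfrecords")
    for index in range(66742):  # one iteration per image pair
        blur_image_path = cwd + get_file_name(index, True)
        orig_image_path = cwd + get_file_name(index, False)

        # convert("RGB") guarantees 3 channels (a PNG may be RGBA or grayscale)
        blur_image = Image.open(blur_image_path).convert("RGB")
        orig_image = Image.open(orig_image_path).convert("RGB")

        # PIL's resize() takes (width, height): the stored images are
        # IMAGE_HEIGH pixels wide and IMAGE_WIDTH pixels tall
        blur_image = blur_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))
        orig_image = orig_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))

        # Raw uint8 pixels, row by row, 3 bytes per pixel
        blur_image_raw = blur_image.tobytes()
        orig_image_raw = orig_image.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            'blur_image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[blur_image_raw])),
            'orig_image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[orig_image_raw])),
        }))
        # write() must be inside the loop, otherwise only the last pair is saved
        writer.write(example.SerializeToString())
    writer.close()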
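Each record stores an image only as the raw bytes returned by PIL's `tobytes()`: for an RGB image, a header-less, row-major buffer of height x width x 3 uint8 values. Nothing in the record says how large the image was, so the reader must know the shape in advance and list rows before columns when reshaping. A pure-Python sketch of that layout and of the equivalent of the later `decode_raw` + `reshape` step; `pixel_offset` and `to_hwc` are hypothetical helpers, not TensorFlow or PIL functions:

```python
# Hypothetical helpers showing how tobytes() output maps back to pixels.

def pixel_offset(row: int, col: int, channel: int, width: int, channels: int = 3) -> int:
    """Byte offset of (row, col, channel) in a row-major H x W x C buffer."""
    return (row * width + col) * channels + channel


def to_hwc(raw: bytes, height: int, width: int, channels: int = 3):
    """Pure-Python analogue of reshaping decoded uint8 bytes to [height, width, channels]."""
    if len(raw) != height * width * channels:
        raise ValueError("buffer size does not match height * width * channels")
    # Group consecutive bytes into pixels, then consecutive pixels into rows.
    pixels = [list(raw[i:i + channels]) for i in range(0, len(raw), channels)]
    return [pixels[r * width:(r + 1) * width] for r in range(height)]


# A 2 x 3 (height x width) RGB "image" whose bytes are simply 0, 1, ..., 17.
raw = bytes(range(2 * 3 * 3))
img = to_hwc(raw, height=2, width=3)
print(img[1][2])  # [15, 16, 17]: last pixel of the second row
```

Because the buffer carries no shape information, reshaping with the two dimensions swapped raises no error as long as height * width is unchanged; for a non-square image it silently scrambles the pixels.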

Reading the dataset back:

def read_and_decode(filename):
    # Queue of input file names; by default it cycles over the file indefinitely
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'blur_image_raw': tf.FixedLenFeature([], tf.string),
                                           'orig_image_raw': tf.FixedLenFeature([], tf.string),
                                       })

    # tobytes() wrote rows x columns x channels; after resize((IMAGE_HEIGH, IMAGE_WIDTH))
    # the images have IMAGE_WIDTH rows, hence this shape.
    # Pixel values are then scaled from [0, 255] to [-0.5, 0.5].
    blur_img = tf.decode_raw(features['blur_image_raw'], tf.uint8)
    blur_img = tf.reshape(blur_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
    blur_img = tf.cast(blur_img, tf.float32) * (1. / 255) - 0.5

    # The sharp target comes from its own feature, not from 'blur_image_raw'
    orig_img = tf.decode_raw(features['orig_image_raw'], tf.uint8)
    orig_img = tf.reshape(orig_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
    orig_img = tf.cast(orig_img, tf.float32) * (1. / 255) - 0.5

    return blur_img, orig_img


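if __name__ == '__main__':

    # create_cord()  # run once to build train.tfrecords

    blur, orig = read_and_decode("train.tfrecords")
    blur_batch, orig_batch = tf.train.shuffle_batch([blur, orig],
                                                    batch_size=3, capacity=1000,
                                                    min_after_dequeue=100)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Start the queue-runner threads that fill the input queues
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(3):
            v, l = sess.run([blur_batch, orig_batch])
            print(v.shape, l.shape)  # (3, IMAGE_WIDTH, IMAGE_HEIGH, 3) for both
        # Stop the background threads cleanly before the session closes
        coord.request_stop()
        coord.join(threads)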
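`shuffle_batch` keeps a queue of decoded examples, filled by the background threads, and hands out random batches from it: `min_after_dequeue` is the minimum number of elements left in the queue after a dequeue (which controls how well examples get mixed), and `capacity` bounds the queue's size. A simplified, single-threaded pure-Python model of that idea follows; `shuffle_batches` is a hypothetical helper, not TensorFlow's implementation:

```python
import random


def shuffle_batches(stream, batch_size, min_after_dequeue, seed=None):
    """Simplified model of a shuffling batch queue (hypothetical helper).

    Elements from `stream` go into a buffer; a batch is drawn at random only
    once the buffer holds min_after_dequeue + batch_size elements, so at least
    min_after_dequeue elements stay behind to be mixed with later input.
    """
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) >= min_after_dequeue + batch_size:
            # Remove batch_size randomly chosen elements from the buffer.
            yield [buffer.pop(rng.randrange(len(buffer))) for _ in range(batch_size)]
    # A finite stream leaves leftovers: shuffle them and emit the rest
    # (the last batch may be smaller than batch_size).
    rng.shuffle(buffer)
    for start in range(0, len(buffer), batch_size):
        yield buffer[start:start + batch_size]


# Usage: ten "examples", batches of 3, at least 4 left in the buffer.
for batch in shuffle_batches(range(10), batch_size=3, min_after_dequeue=4, seed=0):
    print(batch)
```

A larger `min_after_dequeue` mixes examples better but holds more decoded images in memory and delays the first batch; with `min_after_dequeue=100` and `batch_size=3`, the answer's `capacity=1000` leaves plenty of room. Since `string_input_producer` cycles over the file indefinitely by default, the tail handling in this sketch never comes into play in the answer's setup.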
Answered 2017-03-04T09:06:47.720