2

一周前我刚刚开始使用 TensorFlow,我遇到了一些基本问题。

主要的一个是我没有找到一种方法来创建包含我所有数据的 TFRecords。我知道这个过程是必要的,以便用几百万张 32x32 像素的图像训练我自己的网络。

我发现很多教程和很多文档都提到了“input_pipeline”,但是这些教程都没有清楚地解释如何使用自己的图像创建自己的数据库。

我有几个主要文件夹和一些子文件夹,每个文件夹约 300,000 张 png 图像,其中标签位于图像名称中(0 或 1 - 二进制分类)。

获取这些图像的方法是通过 (glob) 行:

"/home/roishik/Desktop/database/train/exp*/*png"
"/home/roishik/Desktop/database/train/exp*/tot*/*png"

所以我的问题是:

如何创建包含这些图像及其标签的 TFRecords 文件?

我会非常感谢你的帮助!我被这个问题困扰了将近两天,我只找到了关于 MNIT 和 ImageNet 的非常具体的答案。

谢谢!

4

1 回答 1

4

数以百万计的 32x32 图像?听起来完全像 CIFAR。查看TensorFlow 模型,他们有一个脚本可以下载 CIFAR10 并将其转换为 TFRecords:download_and_convert_data.py。如果您的数据不是 CIFAR,请检查代码,它可能会对您有所帮助。

加载 CIFAR10 的代码如下所示:

with tf.Graph().as_default():
    image_placeholder = tf.placeholder(dtype=tf.uint8)
    encoded_image = tf.image.encode_png(image_placeholder)

    with tf.Session('') as sess:
        for j in range(num_images):
            [...] # load image and label from disk
            image = [...]
            label = [...]

            png_string = sess.run(encoded_image,
                                  feed_dict={image_placeholder: image})

            example = dataset_utils.image_to_tfexample(
                png_string, 'png', _IMAGE_SIZE, _IMAGE_SIZE, label)
            tfrecord_writer.write(example.SerializeToString())
            [...]

image_to_tfexample()函数如下所示:

def image_to_tfexample(image_data, image_format, height, width, class_id):
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }))

int_64_feature()函数看起来是这样的(函数bytes_feature()类似):

def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

编辑

更多细节:

  • 像这样创建(这TFRecordWriter也创建了文件):

    with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
        [...] # use the tfrecord_writer
    
  • 文档tf.image.encode_png()说图像应该具有 shape [height, width, channels],其中channels = 1用于灰度,channels = 2用于灰度 + alpha,3 用于 RGB 颜色,以及channels = 4用于 RGB 颜色 + alpha (RGBA)。

于 2016-09-13T07:45:02.670 回答