数以百万计的 32x32 图像?听起来完全像 CIFAR。查看TensorFlow 模型,他们有一个脚本可以下载 CIFAR10 并将其转换为 TFRecords:download_and_convert_data.py。如果您的数据不是 CIFAR,请检查代码,它可能会对您有所帮助。
加载 CIFAR10 的代码如下所示:
with tf.Graph().as_default():
image_placeholder = tf.placeholder(dtype=tf.uint8)
encoded_image = tf.image.encode_png(image_placeholder)
with tf.Session('') as sess:
for j in range(num_images):
[...] # load image and label from disk
image = [...]
label = [...]
png_string = sess.run(encoded_image,
feed_dict={image_placeholder: image})
example = dataset_utils.image_to_tfexample(
png_string, 'png', _IMAGE_SIZE, _IMAGE_SIZE, label)
tfrecord_writer.write(example.SerializeToString())
[...]
该image_to_tfexample()
函数如下所示:
def image_to_tfexample(image_data, image_format, height, width, class_id):
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': bytes_feature(image_data),
'image/format': bytes_feature(image_format),
'image/class/label': int64_feature(class_id),
'image/height': int64_feature(height),
'image/width': int64_feature(width),
}))
int_64_feature()
函数看起来是这样的(函数bytes_feature()
类似):
def int64_feature(values):
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
编辑
更多细节:
像这样创建(这TFRecordWriter
也创建了文件):
with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
[...] # use the tfrecord_writer
的文档tf.image.encode_png()
说图像应该具有 shape [height, width, channels]
,其中channels = 1
用于灰度,channels = 2
用于灰度 + alpha,3 用于 RGB 颜色,以及channels = 4
用于 RGB 颜色 + alpha (RGBA)。