我有一堆格式类似于 Cifar10 的图像(二进制文件,size = 96*96*3
每个图像的字节数),一个接一个的图像(STL-10 数据集)。我打开的文件有 138MB。
我试图阅读并检查包含图像的张量的内容,以确保阅读正确,但是我有两个问题 -
- 是否
FixedLengthRecordReader
加载整个文件,但一次只提供一个输入?由于读取第一个size
字节应该相对较快。但是,代码运行大约需要两分钟。 - 如何以可显示的格式获取实际的图像内容,或在内部显示它们以验证图像是否被正确读取?我做了
sess.run(uint8image)
,但结果是空的。
代码如下:
import tensorflow as tf
def read_stl10(filename_queue):
class STL10Record(object):
pass
result = STL10Record()
result.height = 96
result.width = 96
result.depth = 3
image_bytes = result.height * result.width * result.depth
record_bytes = image_bytes
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
result.key, value = reader.read(filename_queue)
print value
record_bytes = tf.decode_raw(value, tf.uint8)
depth_major = tf.reshape(tf.slice(record_bytes, [0], [image_bytes]),
[result.depth, result.height, result.width])
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
return result
# probably a hack since I should've provided a string tensor
filename_queue = tf.train.string_input_producer(['./data/train_X'])
image = read_stl10(filename_queue)
print image.uint8image
with tf.Session() as sess:
result = sess.run(image.uint8image)
print result, type(result)
输出:
Tensor("ReaderRead:1", shape=TensorShape([]), dtype=string)
Tensor("transpose:0", shape=TensorShape([Dimension(96), Dimension(96), Dimension(3)]), dtype=uint8)
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
[empty line for last print]
Process finished with exit code 137
我在我的 CPU 上运行它,如果这增加了任何东西的话。
编辑:感谢 Rosa,我找到了纯 TensorFlow 解决方案。显然,在使用 时string_input_producer
,为了查看结果,您需要初始化队列运行器。唯一需要添加到上面代码的是下面的第二行:
...
with tf.Session() as sess:
tf.train.start_queue_runners(sess=sess)
...
之后,result
可以用 显示中的图像matplotlib.pyplot.imshow(result)
。我希望这可以帮助别人。如果您还有其他问题,请随时问我或查看 Rosa 答案中的链接。