我正在尝试使用TF-SLIM dataset classes解析数据。
我将两个图像组合成一个 640x480x6 numpy 数组(因为我组合了两个图像的 RGB 通道)并将它们序列化以将它们保存到.tfrecords文件中。这是执行此操作的代码。
img_pair = combine_images(images[i][0],images[i][1])
img_flo = read_flo_file(labels[i][0])
height = img_pair.shape[0]
width = img_pair.shape[1]
img = img_pair.tostring()
flo = img_flo.tostring()
example = image_to_tfexample(
img, height, width, flo)
tfrecord_writer.write(example.SerializeToString())
def image_to_tfexample(image_data, height, width, flo):
return tf.train.Example(features=tf.train.Features(feature={
'image/img_pair': bytes_feature(image_data),
'image/flo': bytes_feature(flo),
'image/height': int64_feature(height),
'image/width': int64_feature(width),
}))
def combine_images(img1,img2):
img1 = np.array(Image.open(img1))
img2 = np.array(Image.open(img2))
return np.concatenate((img1,img2),axis=-1)
其中 img_pair 是一个640x480x6的numpy 数组,而 flo 是一个640x480x2的numpy 数组。
现在我想阅读这些例子。这是我到目前为止从 tf-slim flower.py (已更新以适合我)示例中获得的内容。
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
if split_name not in SPLITS_TO_SIZES:
raise ValueError('split name %s was not recognized.' % split_name)
if not file_pattern:
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
# Allowing None in the signature so that dataset_factory can use the default.
if reader is None:
reader = tf.TFRecordReader
keys_to_features = {
'image/width': tf.FixedLenFeature([], tf.int64),
'image/height': tf.FixedLenFeature([], tf.int64),
'image/img_pair': tf.FixedLenFeature([], tf.string),
'image/flo': tf.FixedLenFeature([], tf.string)
}
items_to_handlers = {
'image': slim.tfexample_decoder.Tensor('image/img_pair'),
'label': slim.tfexample_decoder.Tensor('image/flo'),
}
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS)
这里的问题是image/img_pair 和 image/flo是二进制字符串。据我所知,它们首先需要转换为张量才能将它们作为 item_handlers 提供。
像这样。
items_to_handlers = {
'image': slim.tfexample_decoder.Tensor('image/img_pair'),
'label': slim.tfexample_decoder.Tensor('image/flo'),
}
我不知道如何将它解析回具有相同形状的张量,即 img_pair 为 640x480x6,flo 为 640x480x2。
我得到一个错误。
Will save model to /tmp/tfslim_model/
Traceback (most recent call last):
File "main.py", line 16, in <module>
images, _, labels = helpers.load_batch(dataset)
File "/home/muazzam/mywork/python/thesis/SceneflowTensorflow/new_stuff/helpers.py", line 36, in load_batch
common_queue_min=8)
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/dataset_data_provider.py", line 97, in __init__
tensors = dataset.decoder.decode(data, items)
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py", line 424, in decode
outputs.append(handler.tensors_to_item(keys_to_tensors))
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py", line 321, in tensors_to_item
return self._decode(image_buffer, image_format)
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py", line 350, in _decode
pred_fn_pairs, default=decode_image, exclusive=True)
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3169, in case
case_seq = _build_case()
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3151, in _build_case
strict=strict, name="If_%d" % i)
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 296, in new_func
return func(*args, **kwargs)
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1819, in cond
orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1694, in BuildCondBranch
original_result = fn()
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py", line 338, in decode_image
return image_ops.decode_image(image_buffer, self._channels)
File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/image_ops_impl.py", line 1346, in decode_image
raise ValueError('channels must be in (None, 0, 1, 3, 4)')
ValueError: channels must be in (None, 0, 1, 3, 4)