I am generating a tfrecords file with the following code:

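import io
import os

import numpy as np
import PIL.Image as pil
import tensorflow as tf

# Helpers assumed to match the ones used in the script (not shown in the original post).
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def generate_tfrecords(data_path, labels, name):
    """Converts a dataset to tfrecords."""
    # args: argparse namespace defined elsewhere in the script
    filename = os.path.join(args.tfrecords_path, name + '.tfrecords')
    writer = tf.python_io.TFRecordWriter(filename)
    for index, data in enumerate(data_path):
        with tf.gfile.GFile(data, 'rb') as fid:
            encoded_jpg = fid.read()
            print(len(encoded_jpg))   # 17904
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = pil.open(encoded_jpg_io)
        image = np.asarray(image)
        print(image.shape)    # (112, 112, 3)
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(int(image.shape[0])),
            'width': _int64_feature(int(image.shape[1])),
            'depth': _int64_feature(3),
            'label': _int64_feature(int(labels[index])),
            'image_raw': _bytes_feature(encoded_jpg)}))  # the encoded file bytes, not decoded pixels
        writer.write(example.SerializeToString())
    writer.close()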

In the code above, encoded_jpg has a length of 17904, which does not match the image shape of 112*112*3.
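
The two numbers measure different things. A PNG file holds a signature, header chunks and deflate-compressed scanlines, so its byte count has no fixed relation to height * width * depth, and tf.decode_raw(..., tf.uint8) simply turns each of those encoded bytes into one uint8 value (hence the 17904 values in the error). The sketch below illustrates this with only the Python standard library; encode_png is a hypothetical, unfiltered 8-bit RGB writer made up for illustration (not how PIL or TensorFlow encode PNGs), applied to a synthetic flat-colored 112x112x3 image.

```python
import struct
import zlib


def encode_png(pixels: bytes, width: int, height: int) -> bytes:
    """Hypothetical minimal PNG writer: 8-bit RGB, non-interlaced, filter type 0 on every row."""
    def chunk(tag: bytes, data: bytes) -> bytes:
        # Each chunk: 4-byte length, 4-byte tag, payload, CRC32 over tag + payload.
        crc = zlib.crc32(tag + data) & 0xFFFFFFFF
        return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)

    stride = width * 3
    # Prefix every scanline with filter byte 0 ("None"), then deflate the whole stream.
    scanlines = b"".join(
        b"\x00" + pixels[y * stride:(y + 1) * stride] for y in range(height)
    )
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)  # bit depth 8, color type 2 = RGB
    return (
        b"\x89PNG\r\n\x1a\n"
        + chunk(b"IHDR", ihdr)
        + chunk(b"IDAT", zlib.compress(scanlines))
        + chunk(b"IEND", b"")
    )


height, width, depth = 112, 112, 3
# Synthetic image: the same (200, 30, 30) RGB triple for every pixel.
pixels = bytes([200, 30, 30]) * (height * width)
png = encode_png(pixels, width, height)

# A decode_raw(..., tf.uint8)-style view yields one value per *encoded* byte.
raw_view = list(png)

print(len(pixels))    # 37632 == 112 * 112 * 3 values after real decoding
print(len(raw_view))  # far fewer values here, so reshaping to [112, 112, 3] would fail
```

A real photo compresses less than this flat image, but its encoded size is still unrelated to 37632, which is why the reshape in the parser fails.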

When I parse the tfrecords with the following code:

import tensorflow as tf

def _parse_function(example_proto):
    features = {'height': tf.FixedLenFeature((), tf.int64, default_value=0),
                'width': tf.FixedLenFeature((), tf.int64, default_value=0),
                'depth': tf.FixedLenFeature((), tf.int64, default_value=0),
                'label': tf.FixedLenFeature((), tf.int64, default_value=0),
                'image_raw': tf.FixedLenFeature((), tf.string, default_value="")}
    parsed_features = tf.parse_single_example(example_proto, features)
    height = tf.cast(parsed_features["height"], tf.int32)  # 112
    width = tf.cast(parsed_features["width"], tf.int32)    # 112
    depth = tf.cast(parsed_features["depth"], tf.int32)    # 3
    label = parsed_features['label']
    img = tf.decode_raw(parsed_features['image_raw'], tf.uint8, little_endian=True)
    img = tf.reshape(img, [height, width, depth])
    return img, label

Running the code above produces the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 17904 values, but the requested shape has 37632
 [[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw, Reshape/shape)]]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?,?,?], [?]], output_types=[DT_UINT8, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]

How can I fix this? The images are PNG files, and 37632 = 112*112*3. Thanks!

1 Answer

Use tf.image.decode_jpeg instead of tf.decode_raw.

answered 2017-11-27T09:14:02.200
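
The answer's point is that the stored bytes are a compressed image file, so they must go through an image decoder, which reads the header and inflates the pixel data into exactly height * width * channels values, before tf.reshape can succeed. Because these files are PNGs, tf.image.decode_png (or the format-agnostic tf.image.decode_image) is the decoder that matches the format; in the question's _parse_function that would presumably mean something like img = tf.image.decode_png(parsed_features['image_raw'], channels=3) in place of the tf.decode_raw line, though that exact line is an untested assumption. To show what such a decoder does without requiring TensorFlow, the following standard-library sketch decodes a hand-built 2x2 PNG. decode_png_rgb and _chunk are hypothetical helpers that handle only 8-bit, non-interlaced RGB and skip CRC checks, so this is an illustration, not a substitute for a real decoder.

```python
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def decode_png_rgb(data: bytes):
    """Hypothetical minimal decoder for 8-bit, non-interlaced RGB PNGs.

    Returns (height, width, channels, pixel_bytes) with
    len(pixel_bytes) == height * width * channels.
    """
    if data[:8] != PNG_SIGNATURE:
        raise ValueError("not a PNG file")
    pos, idat = 8, b""
    width = height = None
    while pos < len(data):
        (length,) = struct.unpack(">I", data[pos:pos + 4])
        tag = data[pos + 4:pos + 8]
        body = data[pos + 8:pos + 8 + length]
        if tag == b"IHDR":
            width, height, bit_depth, color_type, _, _, interlace = struct.unpack(">IIBBBBB", body)
            if (bit_depth, color_type, interlace) != (8, 2, 0):
                raise ValueError("this sketch only handles 8-bit non-interlaced RGB")
        elif tag == b"IDAT":
            idat += body  # pixel data may be split across several IDAT chunks
        elif tag == b"IEND":
            break
        pos += 12 + length  # length + tag + body + CRC (CRC not verified here)

    raw = zlib.decompress(idat)
    bpp, stride = 3, width * 3
    out, prev = bytearray(), bytearray(stride)
    for y in range(height):
        start = y * (stride + 1)
        filter_type = raw[start]  # each scanline starts with a filter-type byte
        line = bytearray(raw[start + 1:start + 1 + stride])
        for i in range(stride):
            a = line[i - bpp] if i >= bpp else 0  # already-reconstructed byte to the left
            b = prev[i]                            # byte above
            c = prev[i - bpp] if i >= bpp else 0   # byte above-left
            if filter_type == 0:
                pred = 0
            elif filter_type == 1:
                pred = a
            elif filter_type == 2:
                pred = b
            elif filter_type == 3:
                pred = (a + b) // 2
            elif filter_type == 4:
                p = a + b - c
                pa, pb, pc = abs(p - a), abs(p - b), abs(p - c)
                pred = a if pa <= pb and pa <= pc else (b if pb <= pc else c)
            else:
                raise ValueError(f"unknown filter type {filter_type}")
            line[i] = (line[i] + pred) & 0xFF
        out += line
        prev = line
    return height, width, 3, bytes(out)


def _chunk(tag: bytes, body: bytes) -> bytes:
    crc = zlib.crc32(tag + body) & 0xFFFFFFFF
    return struct.pack(">I", len(body)) + tag + body + struct.pack(">I", crc)


# Hand-built 2x2 RGB PNG: row 0 uses the Sub filter (1), row 1 the Up filter (2).
scanlines = bytes([1, 10, 20, 30, 30, 30, 30,
                   2, 1, 2, 3, 4, 5, 6])
png_bytes = (PNG_SIGNATURE
             + _chunk(b"IHDR", struct.pack(">IIBBBBB", 2, 2, 8, 2, 0, 0, 0))
             + _chunk(b"IDAT", zlib.compress(scanlines))
             + _chunk(b"IEND", b""))

height, width, channels, pixels = decode_png_rgb(png_bytes)
print(len(pixels), height * width * channels)  # 12 12: counts match, so a reshape works
print(list(pixels))  # [10, 20, 30, 40, 50, 60, 11, 22, 33, 44, 55, 66]
```

By contrast, len(png_bytes) (what a decode_raw-style view would produce) is larger than 12 and could not be reshaped to [2, 2, 3].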