我正在使用以下代码生成 tfrecords 文件。
def generate_tfrecords(data_path, labels, name):
"""Converts a dataset to tfrecords."""
filename = os.path.join(args.tfrecords_path, name + '.tfrecords')
writer = tf.python_io.TFRecordWriter(filename)
for index, data in enumerate(data_path):
with tf.gfile.GFile(data, 'rb') as fid:
encoded_jpg = fid.read()
print(len(encoded_jpg)) # 17904
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = pil.open(encoded_jpg_io)
image = np.asarray(image)
print(image.shape) # 112*112*3
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(int(image.shape[0])),
'width': _int64_feature(int(image.shape[1])),
'depth': _int64_feature(int(3)),
'label': _int64_feature(int(labels[index])),
'image_raw': _bytes_feature(encoded_jpg)}))
writer.write(example.SerializeToString())
writer.close()
在上面的代码中,encoded_jpg
有长度17904
,并且图像的形状112*112*3
不一致。
当我使用以下代码解析 tfrecords 时:
def _parse_function(example_proto):
features = {'height': tf.FixedLenFeature((), tf.int64, default_value=0),
'width': tf.FixedLenFeature((), tf.int64, default_value=0),
'depth': tf.FixedLenFeature((), tf.int64, default_value=0),
'label': tf.FixedLenFeature((), tf.int64, default_value=0),
'image_raw': tf.FixedLenFeature((), tf.string, default_value="")}
parsed_features = tf.parse_single_example(example_proto, features)
height = tf.cast(parsed_features["height"], tf.int32) # 112
width = tf.cast(parsed_features["width"], tf.int32) # 112
depth = tf.cast(parsed_features["depth"], tf.int32) #3
label = parsed_features['label']
img = tf.decode_raw(parsed_features['image_raw'], tf.uint8, little_endian=True)
img = tf.reshape(img, [height, width, depth])
return img, label
当我使用上面的代码时,出现以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 17904 values, but the requested shape has 37632
[[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw, Reshape/shape)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?,?,?], [?]], output_types=[DT_UINT8, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]
我怎么能解决这个问题。图像类型为png
, 和37632=112*112*3
。谢谢!