
I ran into the following problem while creating and loading a .tfrecord file:

Generating the dataset.tfrecord file

The folder /Batch_manager/assets contains some *.tif images, which are used to generate the dataset.tfrecord file:

import os
import tensorflow as tf  # TF 1.x API (tf.python_io)

def _save_as_tfrecord(self, path, name):
    self.__filename = os.path.join(path, name + '.tfrecord')
    print('Writing', self.__filename)
    # The context manager closes the writer even if an exception is raised
    with tf.python_io.TFRecordWriter(self.__filename) as writer:
        for index, img in enumerate(self.load(get_iterator=True, n_images=1)):
            img = img[0]
            # tobytes() replaces the deprecated ndarray.tostring();
            # the bytes are stored in img's own dtype
            image_raw = img.tobytes()
            rows, cols = img.shape[0], img.shape[1]
            # Single-channel images have no third axis
            depth = img.shape[2] if img.ndim == 3 else 1
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': self._int64_feature(rows),
                'width': self._int64_feature(cols),
                'depth': self._int64_feature(depth),
                'label': self._int64_feature(int(self.target[index])),
                'image_raw': self._bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())

Reading from the dataset.tfrecord file

Next, I try to read from dataset.tfrecord, passing the path to the file:

def dataset_input_fn(self, path):
    dataset = tf.contrib.data.TFRecordDataset(path)

    def parser(record):
        keys_to_features = {
            "height": tf.FixedLenFeature((), tf.int64, default_value=""),
            "width": tf.FixedLenFeature((), tf.int64, default_value=""),
            "depth": tf.FixedLenFeature((), tf.int64, default_value=""),
            "label": tf.FixedLenFeature((), tf.int64, default_value=""),
            "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
        }
        print(record)
        features = tf.parse_single_example(record, features=keys_to_features)
        print(features)
        label = features['label']
        height = features['height']
        width = features['width']
        depth = features['depth']
        image = tf.decode_raw(features['image_raw'], tf.float32) 
        image = tf.reshape(image, [height, width, -1])
        label = tf.cast(features["label"], tf.int32)

        return {"image_raw": image, "height": height, "width": width, "depth":depth, "label":label}

    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(32)
    iterator = dataset.make_one_shot_iterator()

    # `features` is a dictionary in which each value is a batch of values for
    # that feature; the labels are included under the "label" key.
    features = iterator.get_next()

    return features
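Separately from the TypeError, it is worth checking that the dtypes agree: `image_raw` holds the bytes of `img` in whatever dtype the .tif was loaded as, while the parser decodes them with `tf.decode_raw(..., tf.float32)`. The two only match if the images really are float32. A numpy-only sketch of that round trip, using a hypothetical uint8 image (the dtype is an assumption; the question doesn't say):

```python
import numpy as np

# Hypothetical 4x5 single-channel image standing in for a loaded .tif (assumed uint8).
img = np.arange(20, dtype=np.uint8).reshape(4, 5)

# Same as image_raw = img.tobytes() in the writer: raw bytes, no dtype or shape kept.
image_raw = img.tobytes()
rows, cols = img.shape[0], img.shape[1]

# Decoding with the original dtype restores the image exactly.
restored = np.frombuffer(image_raw, dtype=np.uint8).reshape(rows, cols, -1)

# Decoding the same bytes as float32 (as the parser does) reinterprets every 4 bytes
# as one float: 20 bytes become 5 values, so a reshape to [4, 5, -1] would fail.
wrong = np.frombuffer(image_raw, dtype=np.float32)
```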

Error message:

TypeError: Expected int64, got '' of type 'str' instead.

What is wrong with this code? I have verified that dataset.tfrecord does contain the correct images and metadata!


1 Answer


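The error occurred because I had copied and pasted this example, which sets the default value of every feature to an empty string via default_value="". An empty string cannot be converted to tf.int64, which is exactly what the error message complains about. Removing default_value="" from all the tf.FixedLenFeature calls solved the problem.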
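A minimal sketch of the corrected parsing spec, assuming TF 2.x (the `tf.io` names and eager execution; under the question's TF 1.x graph mode the spec is the same but the tensors would be evaluated in a session) and a hypothetical 2x3x1 float32 image serialized the same way as in the writer:

```python
import numpy as np
import tensorflow as tf

# Corrected spec: no empty-string default on the int64 features.
keys_to_features = {
    "height": tf.io.FixedLenFeature((), tf.int64),
    "width": tf.io.FixedLenFeature((), tf.int64),
    "depth": tf.io.FixedLenFeature((), tf.int64),
    "label": tf.io.FixedLenFeature((), tf.int64),
    "image_raw": tf.io.FixedLenFeature((), tf.string),
}

# Hypothetical float32 image and label, encoded like the question's writer does.
img = np.arange(6, dtype=np.float32).reshape(2, 3, 1)
ints = {"height": 2, "width": 3, "depth": 1, "label": 7}
feature = {k: tf.train.Feature(int64_list=tf.train.Int64List(value=[v]))
           for k, v in ints.items()}
feature["image_raw"] = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[img.tobytes()]))
record = tf.train.Example(
    features=tf.train.Features(feature=feature)).SerializeToString()

# Parsing now succeeds; rebuild the image from the stored shape.
features = tf.io.parse_single_example(record, features=keys_to_features)
shape = tf.cast(tf.stack([features["height"], features["width"], features["depth"]]),
                tf.int32)
image = tf.reshape(tf.io.decode_raw(features["image_raw"], tf.float32), shape)
```

If a default is wanted for a missing int64 feature, it has to be an integer, for example `default_value=0`; an empty-string default is only valid for the tf.string feature.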

answered 2017-10-09T09:23:58.743