0

我正在尝试在自定义数据集 [图像/标签] 上使用 TFX 所以我使用ImportExampleGen文件.TFRecord如下

example_gen = ImportExampleGen(input_base=TFRecord_DIR_PATH)

context.run(example_gen)

artifact = example_gen.outputs['examples'].get()[0]

我得到了IndexError: list index out of range,因为example_gen.outputs['examples'].get()[]

这是图像到 TFRecord 代码

...
for idx, d in enumerate(str_labels) # LABELS IS DIR NAME (STR):
    imgs = glob.glob(f"..\\PATH\\*.*g")
    str2int[d] = idx
    for img_path in tqdm.tqdm(imgs):
        image = cv2.imread(img_path)[:,:,::-1]
        all_imgs.append(cv2.resize(image, (144, 96)))
        labels.append(idx)

all_imgs = np.array(all_imgs)
labels = np.array(labels)[...,np.newaxis]

with tf.io.TFRecordWriter("TFRecord_DIR_PATH/NAME.tfrecord") as tfrecord:
  for lbl, img in zip(labels, all_imgs):
    label = lbl
    feature = tf.io.serialize_tensor(img)
    features = {
      "label" : tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
      "feature" : tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.numpy()]))
    }
    example = tf.train.Example(features=tf.train.Features(feature=features))
    tfrecord.write(example.SerializeToString())

我正在使用TF==2.7.1TFX==1.6.0

我找不到错误,所以我希望你能。谢谢你。

4

0 回答 0