0

我有一个 tfrecords 文件,不幸的是其中没有标签值。它有两个值: ImageId

因此,要获取标签,我需要查看 pandas DataFrame 中的 Id 以驱动其值,然后根据其值创建标签,例如:

if df[df['id'] == Id]['value'] > threshold_value:
   label = 1
else:
   label = 0

但是,我不知道如何将Tensor("ParseSingleExample/ParseExample/ParseExampleV2:1", shape=(), dtype=string)转换为 python 字符串。

我在这里复制了解析 tfrecords 的代码:

def parse_tf_records(example_input):
    feature_description_dict = {
        IMAGE_FIELD: tf.io.FixedLenFeature(IMAGE_SIZE, tf.float32),
        ID_FIELD: tf.io.FixedLenFeature([], tf.string)
    }
    parsed_example = tf.io.parse_single_example(example_input, feature_description_dict)
    return parsed_example

def read_tfrecord(example_input):
    parsed_example = parse_tf_records(example_input)
    image_data = parsed_example[IMAGE_FIELD]
    id_data = parsed_example[ID_FIELD]
    
    # label = Look for the value of id_data in a Pandas Dataframe and compare the value to threshold_value
    label_data = tf.cast(label, tf.int32)
    return image_data , label_data

我正在使用张量流 2.4.1。如果有人可以帮助我,我真的很感激。谢谢。

4

1 回答 1

0

好的, tf.py_fction 就是答案。这是我的代码,它运行良好:

def get_label(tf_id):
    _id = tf_id.numpy().decode('utf-8')
    if df[df['id'] == _id]['value'] > threshold_value:
       label = 1
    else:
       label = 0
    return tf.cast(label, tf.int32)

def read_tfrecord(example_input):
    parsed_example = parse_tf_records(example_input)
    image_data = parsed_example[IMAGE_FIELD]
    id_data = parsed_example[ID_FIELD]
    label = tf.py_function(get_label, [id_data], tf.int32)
    return image_data , label
于 2021-04-13T04:59:02.260 回答