0

我有一个使用 tf.data 的训练管道。在数据集中有一些坏元素,在我的例子中值为 0。如何根据它们的值删除这些坏数据元素?由于数据集很大,我希望能够在训练时在管道中删除它们。

假设从以下伪代码:

def parse_function(element):
    height = element['height']
    if height <= 0: skip() #How to skip this value

    labels = element['label']
    features['height'] = height

    return features, labels

ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)

建议是根据特征值使用 ds.skip(1),还是提供某种中性的重量/损失?

4

2 回答 2

1

您可以使用tf.data.Dataset.filter

def filter_func(elem):
    """ return True if the element is to be kept """
    return tf.math.greater(elem['height'],0)

ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.filter(filter_func)
于 2021-04-07T12:47:40.863 回答
0

假设这element是您的代码中的一个数据框,那么它将是:

def parse_function(element):
    element = element.query('height>0')

    labels = element['label']
    features['height'] = element['height']

    return features, labels

ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)

`

于 2021-04-07T12:50:43.110 回答