2

我正在使用张量流 2.3.0

我有一个 python 数据生成器-

import tensorflow as tf
import numpy as np

vocab = [1,2,3,4,5]

def create_generator():
    'generates a random number from 0 to len(vocab)-1'
    count = 0
    while count < 4:
        x = np.random.randint(0, len(vocab))
        yield x
        count +=1

我将其设为 tf.data.Dataset 对象

gen = tf.data.Dataset.from_generator(create_generator, 
                                     args=[], 
                                     output_types=tf.int32, 
                                     output_shapes = (), )

现在我想使用map方法对项目进行子采样,这样 tf 生成器将永远不会输出任何偶数。

def subsample(x):
    'remove item if it is present in an even number [2,4]'
    
    '''
    #TODO
    '''
    return x
    
gen = gen.map(subsample)   

如何使用map方法实现这一点?

4

1 回答 1

5

很快不,您不能使用map. 映射函数对数据集的每个元素应用一些转换。您想要的是检查某些谓词的每个元素,并仅获取满足谓词的那些元素。

而那个功能是filter()

所以你可以这样做:

gen = gen.filter(lambda x: x % 2 != 0)

更新:

如果要使用自定义函数而不是lambda,可以执行以下操作:

def filter_func(x):
    if x**2 < 500:
        return True
    return False
gen = gen.filter(filter_func)

如果将此函数传递给filter平方小于 500 的所有数字,则将返回。

于 2020-11-17T05:14:04.857 回答