0

获得一个已构建的 tensorflow 数据集对象 ( tf.data.Dataset),名为data.

有没有办法通过检查数据来知道函数repeat//是否在这个对象上被调用batch?(并可能获得其他信息,例如批处理和重复的参数)shuffle

(我假设急切执行)

编辑 1:似乎 str 方法带有一些信息。调查那个。

编辑 2:属性 output_shapes 提供有关批量大小和形状的信息。

4

1 回答 1

0

我能想到的唯一解决方案是进入 tensorflow 代码。gen_dataset_ops.py是在源码构建过程中生成的,所以只能在本地找到。

另一个文件是dataset_ops.py,它可以在下面的链接中找到。您只需在相关函数的返回之前插入打印语句。例如 shuffle 函数来自dataset_ops.py

def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
"""Randomly shuffles the elements of this dataset.
...
print('Dataset shuffled') #inserted print here
return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)

数据集对象被包装到DatasetV1Adapter中,因此您无法提前了解它。急切模式的唯一区别是它支持显式迭代,但是这样做会非常低效

array = np.random.rand(10)
dataset = tf.data.Dataset.from_tensor_slices(array)
if len([i for i in dataset]) != array.shape[0]:
    print('repeated')

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/data/ops/dataset_ops.py

于 2019-03-28T18:20:51.210 回答