7

对于迁移学习,人们通常使用网络作为特征提取器来创建特征数据集,在该数据集上训练另一个分类器(例如 SVM)。

我想使用 Dataset API ( tf.contrib.data) 和dataset.map()

# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
    dataset = inputs(...)  # This creates a dataset of (image, label) pairs

    def map_example(image, label):
        features = feature_extractor(image, trainable=False)
        #  Leaving out initialization from a checkpoint here... 
        return features, label

    dataset = dataset.map(map_example)

    return dataset

为数据集创建迭代器时,这样做会失败。

ValueError: Cannot capture a stateful node by value.

这是真的,网络的内核和偏差是变量,因此是有状态的。对于这个特定的示例,它们不必是。

有没有办法让 Ops 和特别是tf.Variable对象无状态?

由于我正在使用tf.layers我不能简单地将它们创建为常量,并且设置trainable=False也不会创建常量,但不会将变量添加到GraphKeys.TRAINABLE_VARIABLES集合中。

4

1 回答 1

15

不幸的是,tf.Variable它本质上是有状态的。但是,仅当您使用创建迭代器时才会出现此错误Dataset.make_one_shot_iterator()。*为避免此问题,您可以改用Dataset.make_initializable_iterator(), ,但需要注意的是,在为输入管道中使用的对象运行初始化程序iterator.initializer,您还必须在返回的迭代器上运行。tf.Variable


* 此限制的原因是它用于封装数据集定义的Dataset.make_one_shot_iterator()TensorFlow 函数 ( ) 的实现细节和正在进行中的函数支持。Defun由于使用像查找表和变量这样的有状态资源比我们最初想象的更受欢迎,我们正在寻找放宽这种限制的方法。

于 2017-06-12T16:11:47.070 回答