对于迁移学习,人们通常使用网络作为特征提取器来创建特征数据集,在该数据集上训练另一个分类器(例如 SVM)。
我想使用 Dataset API ( tf.contrib.data
) 和dataset.map()
:
# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
dataset = inputs(...) # This creates a dataset of (image, label) pairs
def map_example(image, label):
features = feature_extractor(image, trainable=False)
# Leaving out initialization from a checkpoint here...
return features, label
dataset = dataset.map(map_example)
return dataset
为数据集创建迭代器时,这样做会失败。
ValueError: Cannot capture a stateful node by value.
这是真的,网络的内核和偏差是变量,因此是有状态的。对于这个特定的示例,它们不必是。
有没有办法让 Ops 和特别是tf.Variable
对象无状态?
由于我正在使用tf.layers
我不能简单地将它们创建为常量,并且设置trainable=False
也不会创建常量,但不会将变量添加到GraphKeys.TRAINABLE_VARIABLES
集合中。