我正在使用 tensorflow-hub、tensorflow-estimators 和 tensorflow-data 构建模型分类。
我的火车函数正在返回数据集,其model_fn
定义如下:
def train_input_fn():
return dataset_input_fn(DATASET_TRAIN_PATH)
def model_fn(features, labels, mode, params):
logging.info("model_fn")
# module is imported from tf-hub
return head.create_estimator_spec (features, mode, ...)
与Damien的代码非常相似。
代码环境为:Python 2,Google cloud datalab,tf.version
为1.12。被触发的错误model_fn
是不期望标签参数(可能由tf-data
数据集生成)。返回数据集的model_fn
给定签名应该是什么?input_fn
请提出任何想法。
非常感谢,
埃拉兰