1

我正在使用 tensorflow-hub、tensorflow-estimators 和 tensorflow-data 构建模型分类。

我的火车函数正在返回数据集,其model_fn定义如下:

def train_input_fn():
      return dataset_input_fn(DATASET_TRAIN_PATH)

def model_fn(features, labels, mode, params): 
  logging.info("model_fn")
  # module is imported from tf-hub
  return head.create_estimator_spec (features, mode, ...)

与Damien的代码非常相似。

代码环境为:Python 2,Google cloud datalab,tf.version为1.12。被触发的错误model_fn是不期望标签参数(可能由tf-data数据集生成)。返回数据集的model_fn给定签名应该是什么?input_fn

请提出任何想法。

非常感谢,

埃拉兰

4

0 回答 0