0

我正在查看这篇 Tensorflow 文章的源代码,该文章讨论了如何创建一个广泛而深入的学习模型。https://www.tensorflow.org/versions/r1.3/tutorials/wide_and_deep

这是python源代码的链接:https ://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py

它的目标是训练一个模型,根据人口普查信息中的数据,该模型将预测某人的年收入是否超过或低于 50ka 美元。

按照指示,我正在运行此命令来执行:

python wide_n_deep_tutorial.py --model_type=wide_n_deep

我得到的结果如下:

$ python wide_n_deep.py --model_type=wide_n_deep
Training data is downloaded to /tmp/tmp_pwqo2h8
Test data is downloaded to /tmp/tmph6jcimik
2018-01-03 05:34:12.236038: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
WARNING:tensorflow:enqueue_data was called with num_epochs and num_threads > 1. num_epochs is applied per thread, so this will produce more epochs than you probably intend. If you want to limit epochs, use one thread.
WARNING:tensorflow:enqueue_data was called with shuffle=False and num_threads > 1. This will create multiple threads, all reading the array/dataframe in order. If you want examples read in order, use one thread; if you want multiple threads, enable shuffling.
WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool.
WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool.
model directory = /tmp/tmp_ab6cfsf
accuracy: 0.808673
accuracy_baseline: 0.763774
auc: 0.841373
auc_precision_recall: 0.66043
average_loss: 0.418642
global_step: 2000
label/mean: 0.236226
loss: 41.8154
prediction/mean: 0.251593

在我在网上看到的各种文章中,它谈到了加载.ckpt文件。当我查看我的模型目录时,我看到了这些文件:

$ ls /tmp/tmp_ab6cfsf
checkpoint  eval  events.out.tfevents.1514957651.ml-1  graph.pbtxt  model.ckpt-1.data-00000-of-00001  model.ckpt-1.index  model.ckpt-1.meta  model.ckpt-2000.data-00000-of-00001  model.ckpt-2000.index  model.ckpt-2000.meta

我猜我会使用的是model.ckpt-1.meta,对吗?

但我也对如何使用和提供这个模型数据感到困惑。我在 Tensorflow 的网站上看过这篇文章:https ://www.tensorflow.org/versions/r1.3/programmers_guide/saved_model

上面写着“请注意,Estimators 会自动保存和恢复变量(在 model_dir 中)。” (不确定在这种情况下这意味着什么)

我如何生成人口普查数据格式的信息,除了工资,因为这是我们应该预测的?对我来说,如何使用两篇 Tensorflow 文章以便能够使用经过训练的模型进行预测并不明显。

4

1 回答 1

1

您可以查看 TensorFlow 团队的官方博客文章(第 1部分和第 3 部分),其中很好地解释了如何使用估算器。

他们特别解释了如何使用自定义输入进行预测。predict这使用了Estimators的内置方法:

estimator = tf.estimator.Estimator(model_fn, ...)

predict_input_fn = ...  # define this using tf.data

predict_results = estimator.predict(predict_input_fn)
for idx, prediction in enumerate(predict_results):
    print(idx)
    for key in prediction:
        print("...{}: {}".format(key, prediction[key]))

对于您的示例,我们可以使用附加的 csv 文件创建预测输入函数。假设我们有一个名为的 csv 文件,"predict.csv"其中包含三个示例(可能是示例的前三行,"test.csv"没有标签)。这将给出:

predict.csv

...跳过此行...
25, Private, 226802, 11th, 7, Never-married, Machine-op-inspct, Own-child, Black, Male, 0, 0, 40, United-States
38, Private, 89814, HS-grad, 9, Married-civ-spouse, Farming-fishing, 丈夫, White, Male, 0, 0, 50, United-United States
28, Local-gov, 336951, Assoc-acdm, 12, Married-civ -spouse, Protective-serv, 丈夫, White, Male, 0, 0, 40, 美国

estimator = build_estimator(FLAGS.model_dir, FLAGS.model_type)

def predict_input_fn(data_file):
    """Input builder function."""
    df_data = pd.read_csv(
        tf.gfile.Open(data_file),
        names=CSV_COLUMNS[:-1],  # remove the last name "income_bracket" that corresponds to the label
        skipinitialspace=True,
        engine="python",
        skiprows=1)
    # remove NaN elements
    df_data = df_data.dropna(how="any", axis=0)
    return tf.estimator.inputs.pandas_input_fn(x=df_data, y=None, shuffle=False)

predict_file_name = "wide_n_deep/predict.csv"
predict_results = estimator.predict(input_fn=predict_input_fn(predict_file_name))
for idx, prediction in enumerate(predict_results):
    print(idx)
    for key in prediction:
        print("...{}: {}".format(key, prediction[key]))
于 2018-01-05T20:16:19.977 回答