
In the TensorFlow census example, LABEL_COLUMN (income_bracket) has the predefined values [' <=50K', ' >50K']. It is a categorical base column.

1) How do I modify model.py so that LABEL_COLUMN is a continuous base column with float values?

2) Can this program be updated to print the predicted values? Currently it only reports the accuracy percentage.

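import multiprocessing

import tensorflow as tf  # TensorFlow 1.x (tf.contrib and queue-based input)

# LABELS, CSV_COLUMNS, CSV_COLUMN_DEFAULTS, UNUSED_COLUMNS and LABEL_COLUMN
# are module-level constants defined elsewhere in model.py.

def parse_label_column(label_string_tensor):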
  # Build a Hash Table inside the graph
  table = tf.contrib.lookup.string_to_index_table_from_tensor(
      tf.constant(LABELS))

  # Use the hash table to convert string labels to ints
  return table.lookup(label_string_tensor)

def generate_input_fn(filenames,
                      num_epochs=None,
                      shuffle=True,
                      skip_header_lines=0,
                      batch_size=40):
  """Generates an input function for training or evaluation.
  Returns:
      A function () -> (features, indices) where features is a dictionary of
        Tensors, and indices is a single Tensor of label indices.
  """
  def _input_fn():
    files = tf.concat([
      tf.train.match_filenames_once(filename)
      for filename in filenames
    ], axis=0)

    filename_queue = tf.train.string_input_producer(
        files, num_epochs=num_epochs, shuffle=shuffle)
    reader = tf.TextLineReader(skip_header_lines=skip_header_lines)

    _, rows = reader.read_up_to(filename_queue, num_records=batch_size)

    # DNNLinearCombinedClassifier expects rank 2 tensors.
    row_columns = tf.expand_dims(rows, -1)
    columns = tf.decode_csv(row_columns, record_defaults=CSV_COLUMN_DEFAULTS)
    features = dict(zip(CSV_COLUMNS, columns))

    # Remove unused columns
    for col in UNUSED_COLUMNS:
      features.pop(col)

    if shuffle:
      # This operation maintains a buffer of Tensors so that inputs are
      # well shuffled even between batches.
      features = tf.train.shuffle_batch(
          features,
          batch_size,
          capacity=batch_size * 10,
          min_after_dequeue=batch_size*2 + 1,
          num_threads=multiprocessing.cpu_count(),
          enqueue_many=True,
          allow_smaller_final_batch=True
      )
    label_tensor = parse_label_column(features.pop(LABEL_COLUMN))
    return features, label_tensor
  return _input_fn
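To see why a float label cannot go through parse_label_column, here is a plain-Python sketch of what that function does: it maps each label string to its index in LABELS, and strings that are not in the table get a default index (-1, which I believe is the TF lookup table's default_value). A dict stands in for the in-graph hash table, so this is only an illustration; parse_label_column_py is a hypothetical name, and LABELS is assumed to be the census values quoted in the question.

```python
# Plain-Python stand-in for the in-graph lookup table built by
# parse_label_column; illustration only, not TensorFlow code.
LABELS = [' <=50K', ' >50K']  # assumed, per the question text


def parse_label_column_py(label_strings, labels=LABELS, default_value=-1):
    """Map each label string to its index in `labels` (unknown -> default_value)."""
    table = {label: index for index, label in enumerate(labels)}
    return [table.get(s, default_value) for s in label_strings]


# A continuous label such as '42.5' is not a key in the table,
# so it collapses to the default index instead of keeping its value.
print(parse_label_column_py([' >50K', ' <=50K', '42.5']))  # [1, 0, -1]
```

This is why the float case drops the lookup entirely and uses the decoded column as the label.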

To make the label a float, you need to make sure the default value of the label column is a float. The following changes are needed:

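# The last entry (the label column) now defaults to a float,
# so tf.decode_csv parses that column as float32.
CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                       [0], [0], [0], [''], [0.0]]

# In _input_fn, use the float label directly instead of parse_label_column:
label_tensor = features.pop(LABEL_COLUMN)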
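The default matters because tf.decode_csv takes each column's type from its record_defaults entry and fills empty fields with that default. The plain-Python sketch below roughly mimics that behavior for a three-column row; decode_csv_py is a hypothetical helper, not a TensorFlow API, and the sample rows are made up. It shows a [0.0] default producing a float label, while an integer [0] default rejects '2.5'.

```python
import csv
import io


def decode_csv_py(line, record_defaults):
    """Parse one CSV line, converting each field to the type of its default.

    Rough, hypothetical imitation of tf.decode_csv: the default's type fixes
    the column type, and an empty field takes the default value.
    """
    fields = next(csv.reader(io.StringIO(line)))
    if len(fields) != len(record_defaults):
        raise ValueError("expected %d columns, got %d"
                         % (len(record_defaults), len(fields)))
    parsed = []
    for field, (default,) in zip(fields, record_defaults):
        # Empty field -> default; otherwise convert with the default's type.
        parsed.append(default if field == '' else type(default)(field))
    return parsed


defaults = [[0], [''], [0.0]]  # last column: float label
print(decode_csv_py("39, Bachelors,1234.5", defaults))  # [39, ' Bachelors', 1234.5]
print(decode_csv_py("40,,", defaults))                  # [40, '', 0.0]
```

With [0] as the last default instead, int('2.5') raises ValueError, just as a mismatched default makes the real CSV decoding fail.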

(You may also want to consider s/LABEL_COLUMN/INCOME_COLUMN/g.)

answered 2017-05-06T14:30:33.257