
When I run my TensorFlow model, I get this error:

InvalidArgumentError: Field 4 in record 0 is not a valid float: latency
  [[Node: DecodeCSV = DecodeCSV[OUT_TYPE=[DT_STRING, DT_STRING, DT_STRING, DT_STRING, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_STRING, DT_STRING, DT_STRING, DT_STRING, DT_STRING, DT_STRING, DT_STRING], field_delim=",", na_value="", use_quote_delim=true](arg0, DecodeCSV/record_defaults_0, DecodeCSV/record_defaults_1, DecodeCSV/record_defaults_2, DecodeCSV/record_defaults_3, DecodeCSV/record_defaults_4, DecodeCSV/record_defaults_5, DecodeCSV/record_defaults_6, DecodeCSV/record_defaults_7, DecodeCSV/record_defaults_8, DecodeCSV/record_defaults_9, DecodeCSV/record_defaults_10, DecodeCSV/record_defaults_11, DecodeCSV/record_defaults_12, DecodeCSV/record_defaults_13)]]
  [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?], [?], [?], [?], [?], [?], [?], [?], [?], [?], [?], [?], [?]], output_types=[DT_STRING, DT_STRING, DT_FLOAT, DT_STRING, DT_STRING, DT_STRING, DT_STRING, DT_STRING, DT_STRING, DT_STRING, DT_FLOAT, DT_FLOAT, DT_STRING, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

I believe the problem is in the preprocessing step that creates the CSV files my model reads from, because the model thinks it should be receiving the first Node's schema but is instead getting the second Node's.
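
To illustrate the mismatch: the DecodeCSV defaults expect field 4 to be a float, yet record 0 apparently contains the literal string latency there, which looks like a stray header row or a shifted column. A rough check along these lines would show what is actually in that field (the CSV path below is only a placeholder for wherever the preprocessing output ends up):

import csv
import tensorflow as tf

# Placeholder path: wherever the preprocessing step writes the training CSV.
CSV_PATH = 'gs://<mybucket>/logs2/preproc_tft/train.csv'

with tf.gfile.Open(CSV_PATH) as f:
    reader = csv.reader(f)
    for i, row in enumerate(reader):
        print 'record {}: field 4 = {!r}'.format(i, row[4])
        try:
            float(row[4])
        except ValueError:
            print '  -> field 4 is not a valid float in this record'
        if i >= 4:
            break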

Here is my preprocessing code (I deliberately CAST the dtypes in the query to make sure they are read correctly; a quick schema check is sketched right after the query):

query = """
SELECT CAST(end_time AS STRING) AS end_time, CAST(device AS STRING) AS device, CAST(device_os AS STRING) AS device_os, CAST(device_os_version AS STRING) AS device_os_version, CAST(latency AS FLOAT) AS latency,
CAST(megacycles AS FLOAT) AS megacycles, CAST(cost AS FLOAT) AS cost, CAST(status AS STRING) AS Status, CAST(device_brand AS STRING) AS device_brand, CAST(device_family AS STRING) AS device_family,
CAST(browser_version AS STRING) AS browser_version, CAST(app AS STRING) AS app, CAST(ua_parse AS STRING) AS ua_parse
FROM [<mytable>:daily_logs.app_logs_data]
WHERE start_time >='2018-04-16'
GROUP BY end_time, device, device_os, device_os_version, latency, megacycles, cost, Status, device_brand, device_family, browser_version, app, ua_parse
"""
def preprocess_tft(inputs):
    import copy
    import numpy as np
    def center(x):
        return x - tft.mean(x)
    result = copy.copy(inputs) # shallow copy
    result['end_time'] = tft.string_to_int(inputs['end_time'])
    result['device'] = tft.string_to_int(inputs['device'])
    result['device_os'] = tft.string_to_int(inputs['device_os'])
    result['device_os_version'] = tft.string_to_int(inputs['device_os_version'])
    result['latency_tft'] = center(inputs['latency'])
    result['megacycles_tft'] = center(inputs['megacycles'])
    result['cost_tft'] = center(inputs['cost'])
    result['Status'] = tft.string_to_int(inputs['Status'])
    result['device_brand'] = tft.string_to_int(inputs['device_brand'])
    result['device_family'] = tft.string_to_int(inputs['device_family'])
    result['browser_version'] = tft.string_to_int(inputs['browser_version'])
    result['app'] = tft.string_to_int(inputs['app'])
    result['ua_parse'] = tft.string_to_int(inputs['ua_parse'])
    return result
    #return inputs

def cleanup(rowdict):
    import copy, hashlib
    CSV_COLUMNS ='end_time,device,device_os,device_os_version,latency,megacycles,cost,Status,device_brand,device_family,browser_version,app,ua_parse'.split(',')
    STR_COLUMNS = 'key,end_time,device,device_os,device_os_version,Status,device_brand,device_family,browser_version,app,ua_parse'.split(',')
    FLT_COLUMNS = 'latency,megacycles,cost'.split(',')

    # add any missing columns, and correct the types
    def tofloat(value, ifnot):
      try:
        return float(value)
      except (ValueError, TypeError):
        return ifnot

    result = {
      k : str(rowdict[k]) if k in rowdict else 'None' for k in STR_COLUMNS
    }
    result.update({
        k : tofloat(rowdict[k], -99) if k in rowdict else -99 for k in FLT_COLUMNS
      })    

    # cleanup: write out only the data that we want to train on
    if result['latency'] > 0 and result['megacycles'] > 0 and result['cost'] > 0:
      data = ','.join([str(result[k]) for k in CSV_COLUMNS])
      result['key'] = hashlib.sha224(data).hexdigest()
      yield result

def preprocess(query, in_test_mode):
  import datetime  # needed for the timestamped job_name below
  import os
  import os.path
  import tempfile
  import tensorflow as tf
  from apache_beam.io import tfrecordio
  from tensorflow_transform.coders import example_proto_coder
  from tensorflow_transform.tf_metadata import dataset_metadata
  from tensorflow_transform.tf_metadata import dataset_schema
  from tensorflow_transform.beam.tft_beam_io import transform_fn_io

  job_name = 'preprocess-log-features' + '-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')    
  if in_test_mode:
    import shutil
    print 'Launching local job ... hang on'
    OUTPUT_DIR = './preproc_tft'
    shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
  else:
    print 'Launching Dataflow job {} ... hang on'.format(job_name)
    OUTPUT_DIR = 'gs://{0}/logs2/preproc_tft/'.format(BUCKET)
    import subprocess
    subprocess.call('gsutil rm -r {}'.format(OUTPUT_DIR).split())

  options = {
    'staging_location': os.path.join(OUTPUT_DIR, 'tmp', 'staging'),
    'temp_location': os.path.join(OUTPUT_DIR, 'tmp'),
    'job_name': job_name,
    'project': PROJECT,
    'max_num_workers': 24,
    'teardown_policy': 'TEARDOWN_ALWAYS',
    'no_save_main_session': True,
    'requirements_file': 'requirements.txt'
  }
  opts = beam.pipeline.PipelineOptions(flags=[], **options)
  if in_test_mode:
    RUNNER = 'DirectRunner'
  else:
    RUNNER = 'DataflowRunner'

  # set up metadata  
  raw_data_schema = {
    colname : dataset_schema.ColumnSchema(tf.string, [], dataset_schema.FixedColumnRepresentation())
                   for colname in 'key,end_time,device,device_os,device_os_version,Status,device_brand,device_family,browser_version,app,ua_parse'.split(',')
  }
  raw_data_schema.update({
      colname : dataset_schema.ColumnSchema(tf.float32, [], dataset_schema.FixedColumnRepresentation())
                for colname in 'latency,megacycles,cost'.split(',')
    })
  raw_data_metadata = dataset_metadata.DatasetMetadata(dataset_schema.Schema(raw_data_schema))

  def read_rawdata(p, step, test_mode):
    if step == 'train':
        selquery = 'SELECT * FROM ({})'.format(query)
    else:
      selquery = 'SELECT * FROM ({})'.format(query)
    if in_test_mode:
        selquery = selquery + ' LIMIT 100'
    #print 'Processing {} data from {}'.format(step, selquery)
    return (p 
          | '{}_read'.format(step) >> beam.io.Read(beam.io.BigQuerySource(query=selquery, use_standard_sql=False))
          | '{}_cleanup'.format(step) >> beam.FlatMap(cleanup)
                   )

  # run Beam  
  with beam.Pipeline(RUNNER, options=opts) as p:
    with beam_impl.Context(temp_dir=os.path.join(OUTPUT_DIR, 'tmp')):

      # analyze and transform training       
      raw_data = read_rawdata(p, 'train', in_test_mode)
      raw_dataset = (raw_data, raw_data_metadata)
      transformed_dataset, transform_fn = (
          raw_dataset | beam_impl.AnalyzeAndTransformDataset(preprocess_tft))
      transformed_data, transformed_metadata = transformed_dataset
      _ = transformed_data | 'WriteTrainData' >> tfrecordio.WriteToTFRecord(
          os.path.join(OUTPUT_DIR, 'train'),
          coder=example_proto_coder.ExampleProtoCoder(
              transformed_metadata.schema))

      # transform eval data
      raw_test_data = read_rawdata(p, 'eval', in_test_mode)
      raw_test_dataset = (raw_test_data, raw_data_metadata)
      transformed_test_dataset = (
          (raw_test_dataset, transform_fn) | beam_impl.TransformDataset())
      transformed_test_data, _ = transformed_test_dataset
      _ = transformed_test_data | 'WriteTestData' >> tfrecordio.WriteToTFRecord(
          os.path.join(OUTPUT_DIR, 'eval'),
          coder=example_proto_coder.ExampleProtoCoder(
              transformed_metadata.schema))
      _ = (transform_fn
           | 'WriteTransformFn' >>
           transform_fn_io.WriteTransformFn(os.path.join(OUTPUT_DIR, 'metadata')))

  job = p.run()
  if in_test_mode:
    job.wait_until_finish()
    print "Done!"

preprocess(query, in_test_mode=False)

Is there a way to print out the dtypes before they are sent to the Beam pipeline, so that I can inspect the dtype array and make sure it is valid? Is there anything in the code that would cause the dtypes to differ from, or be ordered differently than, what is specified in the CSV_COLUMNS variable?
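
Something like the following rough local check is what I have in mind: run a made-up row through cleanup() and print each value's type in CSV_COLUMNS order before launching the Beam job (all sample values below are invented):

CSV_COLUMNS = 'end_time,device,device_os,device_os_version,latency,megacycles,cost,Status,device_brand,device_family,browser_version,app,ua_parse'.split(',')

# A made-up row shaped like one BigQuery result dict (values are invented).
sample_row = {
    'end_time': '2018-04-16 12:00:00', 'device': 'desktop', 'device_os': 'linux',
    'device_os_version': '4.4', 'latency': 0.42, 'megacycles': 100.0, 'cost': 0.003,
    'Status': '200', 'device_brand': 'acme', 'device_family': 'acme-x',
    'browser_version': '65.0', 'app': 'myapp', 'ua_parse': 'Chrome'
}

# cleanup() is a generator, so pull its single output and inspect it.
cleaned = next(cleanup(sample_row))
for k in CSV_COLUMNS:
    print '{}: {} = {}'.format(k, type(cleaned[k]).__name__, cleaned[k])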
