The dataset has many columns, and I feed the data to the model with an `input_fn`:
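```python
import pandas as pd
import tensorflow as tf  # TensorFlow 1.x (tf.gfile, tf.estimator.inputs)

# CSV_COLUMNS (the list of column names) is defined elsewhere in my code.

def input_fn(data_file, num_epochs, shuffle, is_eval):
    """Input builder function."""
    # Read the tab-separated file, skipping its header row.
    df_data = pd.read_csv(
        tf.gfile.Open(data_file),
        names=CSV_COLUMNS,
        skipinitialspace=True,
        engine="python",
        skiprows=1,
        sep="\t")
    # Split the comma-separated "tags" column into a Python list per row.
    df_data["sentence"] = df_data["tags"].apply(lambda x: x.split(","))
    labels = df_data["is_click"]
    return tf.estimator.inputs.numpy_input_fn(
        x={
            "book": df_data["book"],
            "author": df_data["author"],
            "sentence": df_data["sentence"],
        },
        y=labels,
        batch_size=2048,
        num_epochs=num_epochs,
        shuffle=shuffle,
        num_threads=5)
```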
The `sentence` feature is split into a list because I want to one-hot encode the sentence, as in: https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/feature_column/indicator_column
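For reference, the encoding I am after can be sketched in plain NumPy: each comma-separated tag string becomes a count vector over a fixed vocabulary, roughly what `indicator_column` produces for a categorical column with several values per example. The function name `multi_hot` and the vocabulary are hypothetical, and unknown tags are simply skipped in this sketch, which may not match how a vocabulary-list column handles out-of-vocabulary values.

```python
import numpy as np

def multi_hot(tag_strings, vocabulary):
    """Encode each comma-separated tag string as a count vector over `vocabulary`.

    Hypothetical helper: tags not in the vocabulary are ignored here.
    """
    index = {tag: i for i, tag in enumerate(vocabulary)}
    out = np.zeros((len(tag_strings), len(vocabulary)), dtype=np.int64)
    for row, s in enumerate(tag_strings):
        for tag in s.split(","):
            col = index.get(tag.strip())
            if col is not None:
                out[row, col] += 1  # count repeated tags
    return out

vocab = ["fiction", "history", "science"]  # hypothetical vocabulary
encoded = multi_hot(["fiction,history", "science", "history,history"], vocab)
# rows: [1 1 0], [0 0 1], [0 2 0]
```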
But if I split the sentence, I get the following error:
tensorflow.python.framework.errors_impl.InternalError: Unable to get element as bytes.
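My guess (unconfirmed) is that the problem is the shape of the data: after `split`, `df_data["sentence"]` holds one Python list per row, so what reaches `numpy_input_fn` is a 1-D object array of lists rather than a rectangular array of strings. The sketch below uses a hypothetical helper `pad_tags` and an assumed fixed `max_len` to show the difference, and one way to make the feature rectangular by padding with empty strings.

```python
import numpy as np

def pad_tags(tag_strings, max_len, pad=""):
    """Split each comma-separated string and pad/truncate it to exactly max_len tags.

    Hypothetical helper: returns a 2-D object array of str with shape
    (number of rows, max_len) instead of an array of Python lists.
    """
    out = np.empty((len(tag_strings), max_len), dtype=object)
    for i, s in enumerate(tag_strings):
        tags = [t.strip() for t in s.split(",")][:max_len]
        out[i, :] = tags + [pad] * (max_len - len(tags))
    return out

tags = ["fiction,history", "science"]

# What the split in input_fn produces: one list object per row (ragged).
ragged = np.empty(len(tags), dtype=object)
for i, s in enumerate(tags):
    ragged[i] = s.split(",")
print(ragged.shape, type(ragged[0]))

# Padded alternative: a rectangular array whose elements are plain strings.
padded = pad_tags(tags, max_len=3)
print(padded.shape, type(padded[0, 0]))
```

If this is the cause, the `"sentence"` entry in `x` would become something like `pad_tags(df_data["tags"], max_len=10)`; the value 10 and the empty-string padding are assumptions, and how the feature column treats the padding value should be checked for the TensorFlow version in use.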