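I am running a TensorFlow Transform / Apache Beam pipeline that loads and preprocesses data and saves it as TFRecords. I then load those records. During preprocessing in TensorFlow Transform, I want to pad a sparse tensor. To do that, I want to convert it to dense, pad it, and convert it back to sparse.
The code looks roughly like this: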
import tensorflow_transform as tft
import tensorflow as tf
#...
def preprocess_fn(input_features):
    output_features = {}
    output_features[CATEGORICAL_FEATURE_NAMES] = tft.compute_and_apply_vocabulary(...)
    #dense = tf.sparse.to_dense(output_features[CATEGORICAL_FEATURE_NAMES])
    ## do something with dense
    #output_features[CATEGORICAL_FEATURE_NAMES] = tf.contrib.layers.dense_to_sparse(dense)
    return output_features
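For illustration, here is a minimal, library-free sketch of the dense, pad, sparse round trip described above. It does not call TensorFlow, and every helper name in it is made up. It assumes that the conversion back to sparse drops every entry equal to 0, which is how `tf.contrib.layers.dense_to_sparse` behaves with its default `eos_token=0`. Under that assumption a legitimate vocabulary id of 0 disappears and leaves a gap in that row's column indices. Whether such a gap is what triggers the error is an unverified guess.

```python
# Library-free sketch of a sparse batch in COO form (indices + values).
# All helper names are hypothetical; nothing here is TensorFlow API.

def to_dense(indices, values, shape, default=0):
    """Scatter COO entries into a dense list of lists."""
    rows, cols = shape
    dense = [[default] * cols for _ in range(rows)]
    for (r, c), v in zip(indices, values):
        dense[r][c] = v
    return dense

def pad(dense, width, pad_value=0):
    """Right-pad every row to `width` columns."""
    return [row + [pad_value] * (width - len(row)) for row in dense]

def to_sparse(dense, drop_value=0):
    """Keep entries that differ from `drop_value` (assumed to mirror eos_token=0)."""
    indices, values = [], []
    for r, row in enumerate(dense):
        for c, v in enumerate(row):
            if v != drop_value:
                indices.append((r, c))
                values.append(v)
    return indices, values

def rows_are_left_aligned(indices):
    """True if each row's column indices are exactly 0, 1, 2, ... with no gaps."""
    cols_by_row = {}
    for r, c in indices:
        cols_by_row.setdefault(r, []).append(c)
    return all(sorted(cols) == list(range(len(cols)))
               for cols in cols_by_row.values())

# Row 0 holds vocabulary ids [3, 0, 5]; row 1 holds [7].
indices = [(0, 0), (0, 1), (0, 2), (1, 0)]
values = [3, 0, 5, 7]

dense = pad(to_dense(indices, values, shape=(2, 3)), width=4)
new_indices, new_values = to_sparse(dense)

print(dense)                    # padded dense batch
print(new_indices, new_values)  # id 0 at (0, 1) is gone after the round trip
print(rows_are_left_aligned(indices), rows_are_left_aligned(new_indices))
```

In this sketch the original indices are contiguous within each row, while the indices after the round trip are not, because the real id 0 cannot be told apart from padding. If that is what happens in the pipeline, a padding value that cannot collide with a real id would be one thing to try.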
To load the TFRecords, I use the following function:
def tfrecords_input_fn(files_name_pattern, transformed_metadata,
                       mode=tf.estimator.ModeKeys.EVAL,
                       num_epochs=1,
                       batch_size=64):
    dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern=files_name_pattern,
        batch_size=batch_size,
        features=transformed_metadata.schema.as_feature_spec(),
        reader=tf.data.TFRecordDataset,
        num_epochs=num_epochs,
        shuffle=(mode == tf.estimator.ModeKeys.TRAIN),
        shuffle_buffer_size=1 + (batch_size * 2),
        prefetch_buffer_size=1,
        drop_final_batch=True
    )
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()
    target = features.pop(TARGET_FEATURE_NAME)
    return features, target
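Running the whole pipeline (loading the raw data, transforming it, saving the TFRecords, then loading them and printing them to the screen) works fine, but uncommenting the two lines in `preprocess_fn` leads to the following error: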
File ".../lib/python3.6/site-packages/tensorflow_transform/impl_helper.py", line 262, in to_instance_dicts
    raise ValueError('Encountered a SparseTensorValue that cannot be '
ValueError: Encountered a SparseTensorValue that cannot be decoded by ListColumnRepresentation.
...
ValueError: Encountered a SparseTensorValue that cannot be decoded by ListColumnRepresentation. [while running '%s - Transform/ConvertAndUnbatch']
Does anyone have a suggestion for this code, or a hint about what I am missing? Any kind of help is much appreciated!
Best, Dominik