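I have a multi-class NLP classification problem where I need to train on roughly 1 million text samples across about 1,000 class labels. As future datasets are fed into the pipeline, the number of unique class labels will change slightly.
To handle this, I need to set the `depth` parameter of `tf.one_hot` to the number of unique labels, determined dynamically for each training run.
I know that finding the total number of unique labels requires a full pass over the data, so where I'm stuck is how to compute this number.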
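Outside of TensorFlow, the computation itself is a two-pass job: one full pass to collect the distinct labels, then a second pass that encodes each example with a vector whose length is only known after the first pass. A minimal pure-Python sketch, assuming each example carries a list of string labels (`build_label_vocab`, `encode` and the sample data are hypothetical):

```python
from collections import Counter
from typing import Dict, Iterable, List


def build_label_vocab(examples: Iterable[List[str]]) -> Dict[str, int]:
    """First (full) pass: map every distinct label to an integer id.

    Ids are assigned by descending frequency, ties broken alphabetically,
    so the most common label gets id 0.
    """
    counts = Counter(label for labels in examples for label in labels)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: i for i, label in enumerate(ordered)}


def encode(labels: List[str], vocab: Dict[str, int]) -> List[float]:
    """Second pass: turn one example's labels into a multi-hot vector of length len(vocab)."""
    row = [0.0] * len(vocab)
    for label in labels:
        if label in vocab:  # labels unseen in the first pass are ignored
            row[vocab[label]] = 1.0
    return row


data = [["sports", "news"], ["news"], ["tech", "news"]]
vocab = build_label_vocab(data)
print(len(vocab))               # 3
print(vocab)                    # {'news': 0, 'sports': 1, 'tech': 2}
print(encode(data[0], vocab))   # [1.0, 1.0, 0.0]
```

In tf.Transform terms the first pass roughly corresponds to the analyze phase (which `tft.compute_and_apply_vocabulary` already performs) and the second to the transform phase; the label count only becomes a concrete number once the first pass has finished.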
I thought `tft.size` would give me that full pass, but it doesn't seem to work. As you can see below, everything works fine when I hard-code 1000:
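```python
import tensorflow as tf
import tensorflow_transform as tft

# Inside preprocessing_fn(inputs); LABEL_KEY, LABEL_VOCAB_FILE_NAME, transformed_name,
# _fill_in_missing and the `outputs` dict come from the surrounding module (not shown).
labels = inputs[LABEL_KEY]
sparse_labels_tokens = tft.compute_and_apply_vocabulary(labels, vocab_filename=LABEL_VOCAB_FILE_NAME)
# Pad with -1, not the default 0: id 0 is a real label, while tf.one_hot maps -1 to all zeros.
dense_labels_tokens = tf.sparse.to_dense(sparse_labels_tokens, default_value=-1)
# Note: tft.size counts every element of the tensor over the dataset, not distinct labels.
#labels_count = tf.cast( tft.size(dense_labels_tokens), tf.int32 ) #FIXME
labels_count = 1000
# One-hot each label id, then max over the id axis -> one multi-hot row per example.
labels_one_hot = tf.one_hot(dense_labels_tokens, depth=labels_count)
labels_indicators = tf.reduce_max(labels_one_hot, axis=1)
outputs[transformed_name(LABEL_KEY)] = labels_indicators
outputs[LABEL_KEY] = _fill_in_missing(inputs[LABEL_KEY])
```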
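The `tf.one_hot` + `tf.reduce_max` pair turns a padded matrix of label ids into one multi-hot row per example. A small standalone sketch (plain TensorFlow 2 in eager mode; `multi_hot` and the toy ids are hypothetical) shows why the padding value passed to `tf.sparse.to_dense` matters: its default of 0 collides with a real label id, whereas -1 one-hot-encodes to an all-zero row:

```python
import tensorflow as tf


def multi_hot(dense_ids, depth):
    """One-hot every label id, then max over the id axis -> shape (batch, depth)."""
    return tf.reduce_max(tf.one_hot(dense_ids, depth=depth), axis=1)


# Two examples with a different number of labels: {0, 2} and {1}.
sparse_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([0, 2, 1], dtype=tf.int64),
    dense_shape=[2, 2],
)

# Default padding value is 0, which collides with the real label id 0.
padded_with_zero = tf.sparse.to_dense(sparse_ids)
# Padding with -1 is safe: tf.one_hot maps -1 to an all-zero row.
padded_with_minus_one = tf.sparse.to_dense(sparse_ids, default_value=-1)

print(multi_hot(padded_with_zero, 3).numpy())
# [[1. 0. 1.]
#  [1. 1. 0.]]   <- the second example wrongly gains label 0
print(multi_hot(padded_with_minus_one, 3).numpy())
# [[1. 0. 1.]
#  [0. 1. 0.]]
```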
This gives:
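```python
import pprint

import tensorflow as tf

# `dataset` (defined earlier, not shown) is assumed to be a tf.data.TFRecordDataset
# over the transformed examples.
# Iterate over the first few tfrecords and decode them.
for tfrecord in dataset.take(5):
    serialized_example = tfrecord.numpy()
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    pprint.pprint(example)
```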
```
feature {
  key: "label_xf"
  value {
    float_list {
      value: 0.0
      value: 0.0
      value: 1.0
      value: 0.0
```
However, if I use `tft.size` instead, I get the following error:
```
...~/.local/lib/python3.6/site-packages/tensorflow_transform/schema_inference.py in _infer_feature_schema_common(features, tensor_ranges, feature_annotations, global_annotations)
    241 domains[name] = schema_pb2.IntDomain(
    242 min=min_value, max=max_value, is_categorical=True)
--> 243 feature_spec = _feature_spec_from_batched_tensors(features)
    244
    245 schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains)

~/.local/lib/python3.6/site-packages/tensorflow_transform/schema_inference.py in _feature_spec_from_batched_tensors(tensors)
     86 'Feature {} ({}) had invalid shape {} for FixedLenFeature: apart '
     87 'from the batch dimension, all dimensions must have known size'
---> 88 .format(name, tensor, shape))
     89 feature_spec[name] = tf.io.FixedLenFeature(shape.as_list()[1:],
     90 tensor.dtype)

ValueError: Feature label_xf (Tensor("Max:0", shape=(None, None), dtype=float32)) had invalid shape (None, None) for FixedLenFeature: apart from the batch dimension, all dimensions must have known size
```
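The error appears to come from static shape inference: tf.Transform derives each output's `FixedLenFeature` spec from the static shape of the traced graph, where an analyzer result such as `tft.size(...)` is still a tensor rather than a number. `tf.one_hot` can only fill in its last dimension when `depth` is known at trace time. A minimal sketch in plain TensorFlow 2 (no tf.Transform; `multi_hot` is a hypothetical helper) reproduces the difference between a Python-int depth and a tensor depth:

```python
import tensorflow as tf


def multi_hot(tokens, depth):
    """One-hot each token id, then max over the token axis -> (batch, depth)."""
    return tf.reduce_max(tf.one_hot(tokens, depth=depth), axis=1)


# Padded label ids: batch and token dimensions are both unknown, as after tf.sparse.to_dense.
tokens_spec = tf.TensorSpec(shape=[None, None], dtype=tf.int64)

# Depth as a Python int: it is a constant when the graph is traced.
static_fn = tf.function(lambda t: multi_hot(t, 1000)).get_concrete_function(tokens_spec)

# Depth as a tensor argument: its value is only available at run time.
dynamic_fn = tf.function(multi_hot).get_concrete_function(
    tokens_spec, tf.TensorSpec(shape=[], dtype=tf.int32))

print(static_fn.structured_outputs.shape)   # (None, 1000)
print(dynamic_fn.structured_outputs.shape)  # (None, None), the shape rejected in the traceback
```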
I could hard-code the depth to 1500 and cross my fingers that the number of labels never exceeds it, but I'm not sure I could live with myself if I did :(
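The worry about a fixed ceiling is justified in one specific way: `tf.one_hot` does not raise when an id is greater than or equal to `depth`; it emits an all-zero row, so any labels beyond the ceiling would vanish silently. A minimal eager-mode sketch, with a depth of 3 standing in for 1500 (`DEPTH` and the toy ids are hypothetical), shows the silent drop and one cheap guard, `tf.debugging.assert_less`, that fails loudly instead:

```python
import tensorflow as tf

DEPTH = 3  # stand-in for a hard-coded ceiling such as 1500

ids = tf.constant([[0, 4], [1, 2]], dtype=tf.int64)  # id 4 is past the ceiling
indicators = tf.reduce_max(tf.one_hot(ids, depth=DEPTH), axis=1)
print(indicators.numpy())
# [[1. 0. 0.]
#  [0. 1. 1.]]   <- label 4 silently disappeared from the first row

# Guard: raise instead of dropping labels (in eager mode the check runs immediately).
try:
    tf.debugging.assert_less(ids, tf.constant(DEPTH, dtype=tf.int64))
    overflow_detected = False
except tf.errors.InvalidArgumentError:
    overflow_detected = True
print(overflow_detected)  # True
```

A possible alternative, not verified against this pipeline, is to leave the label ids unencoded in the Transform output and build the multi-hot vectors in the trainer, where the vocabulary written by `compute_and_apply_vocabulary` already exists and its size can be read as an ordinary number (tf.Transform's `TFTransformOutput.vocabulary_size_by_name` is intended for this).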