我想构造一批批大小为 16 的数据,使用tf.data
,其中[:8]
是一种数据 A,[8:16]
是一种数据 B。
没有tf.data
. 如果使用tf.data
,代码可能是:
def _decode_record(record, name_to_features):
example = tf.parse_single_example(record, name_to_features)
return example
dataA = tf.data.TFRecordDataset(input_files)
dataA = dataA.apply(
tf.contrib.data.map_and_batch(
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size)
)
接下来怎么做?我尝试:
dataB = tf.data.TFRecordDataset(input_files2)
dataB = dataB.apply(
tf.contrib.data.map_and_batch(
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size)
)
dataC = dataA.concatenate(dataB)
但是concatenate
是:将整个数据集附加dataB
到dataA
.
对于concatenate
,请注意 和name_to_features
应该相同dataA
,dataB
这意味着我应该填充很多虚拟数据。
我不想在of中使用tf.cond
或tf.where
判断不同的数据,这也很难调试。model_fn
tf.estimator