想象一下,我想训练模型,以最小化图像和查询之间的距离。一方面我有来自 CNN 的图像特征,另一方面我有从单词到嵌入向量的映射(例如 w2v):
def raw_data_generator():
for row in network_data:
yield (row["cnn"], row["w2v_indices"])
dataset = tf.data.Dataset.from_generator(raw_data_generator, (tf.float32, tf.int32))
dataset = dataset.prefetch(1000)
在这里我想创建批处理,但我想为 cnn 特征创建密集批处理,为 w2v 创建稀疏批处理,因为显然它具有可变长度(并且我想使用safe_embeddings_lookup_sparse)。密集有批处理功能,稀疏有.apply(tf.contrib.data.dense_to_sparse_batch(..))功能,但如何同时使用它们?