我正在联邦学习设置中实现 skip-gram 模型。我通过以下方式获取输入和标签:
train_inputs_embed = tf.nn.embedding_lookup(variables.weights, batch['target_id'])
train_labels = tf.reshape(batch['context_id'], [-1, 1])
当我将损失定义如下
loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(weights=variables.nce_weights,
biases=variables.bias,
inputs=train_inputs_embed,
labels=train_labels,
num_sampled=5,
num_true=1,
num_classes=vocab_size))
我收到以下错误
ValueError: Shape must be rank 2 but is rank 3 for 'sampled_softmax_loss/concat_4' (op: 'ConcatV2') with input shapes: [?,1], [?,?,5], [].
但是,以下代码(取自 sampled_softmax_loss 函数的 eval 部分)适用于相同的输入和标签!
logits = tf.matmul(train_inputs_embed, tf.transpose(variables.nce_weights))
logits = tf.nn.bias_add(logits, variables.bias)
labels_one_hot = tf.one_hot(train_labels, vocab_size)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels_one_hot, logits=logits))
如何解决此问题?