I am implementing a skip-gram model in a federated learning setting. I obtain the inputs and labels as follows:

train_inputs_embed = tf.nn.embedding_lookup(variables.weights, batch['target_id'])  
train_labels = tf.reshape(batch['context_id'], [-1, 1])

When I define the loss like this

loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(weights=variables.nce_weights,
                                                 biases=variables.bias,
                                                 inputs=train_inputs_embed,
                                                 labels=train_labels,
                                                 num_sampled=5,
                                                 num_true=1,
                                                 num_classes=vocab_size))

I get the following error

 ValueError: Shape must be rank 2 but is rank 3 for 'sampled_softmax_loss/concat_4' (op: 'ConcatV2') with input shapes: [?,1], [?,?,5], [].

However, the following code (taken from the eval part of the sampled_softmax_loss function) works with the same inputs and labels

logits = tf.matmul(train_inputs_embed, tf.transpose(variables.nce_weights))
logits = tf.nn.bias_add(logits, variables.bias)
labels_one_hot = tf.one_hot(train_labels, vocab_size)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels_one_hot, logits=logits))

How can I fix this?

1 Answer

Reshaping train_inputs_embed fixes the error:

train_inputs_embed = tf.reshape(tf.nn.embedding_lookup(variables.weights, batch['target_id']), [-1, embedding_size])
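The reason is that when batch['target_id'] has shape [batch, 1], tf.nn.embedding_lookup keeps that extra axis and returns a rank-3 tensor of shape [batch, 1, embedding_size], while sampled_softmax_loss expects rank-2 inputs of shape [batch, embedding_size]. A minimal NumPy sketch of the shape behavior (the array names and sizes here are made up for illustration):

    import numpy as np

    vocab_size, embedding_size, batch = 10, 4, 3
    weights = np.random.rand(vocab_size, embedding_size)

    # target_id with shape [batch, 1], as it would come out of a data batch
    target_id = np.array([[2], [5], [7]])

    # Fancy indexing mimics tf.nn.embedding_lookup: the singleton axis
    # survives, so the result is rank 3: [batch, 1, embedding_size]
    embedded = weights[target_id]
    assert embedded.shape == (batch, 1, embedding_size)

    # Reshaping to [-1, embedding_size] collapses the extra axis, giving
    # the rank-2 input that sampled_softmax_loss expects
    embedded_2d = embedded.reshape(-1, embedding_size)
    assert embedded_2d.shape == (batch, embedding_size)

The manual eval-style loss happens to work with the rank-3 input only because tf.matmul broadcasts over the leading dimensions, which is why the mismatch surfaces in sampled_softmax_loss but not there.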
answered 2020-02-05T20:13:21.440