我遇到的一些教程使用随机初始化的嵌入矩阵进行描述,然后使用该tf.nn.embedding_lookup
函数获取整数序列的嵌入。我的印象是,由于embedding_matrix
是通过 获得tf.get_variable
的,优化器会添加适当的操作来更新它。
我不明白的是如何通过查找功能发生反向传播,这似乎是硬而不是软。这个操作的梯度是多少?它的输入ID之一?
我遇到的一些教程使用随机初始化的嵌入矩阵进行描述,然后使用该tf.nn.embedding_lookup
函数获取整数序列的嵌入。我的印象是,由于embedding_matrix
是通过 获得tf.get_variable
的,优化器会添加适当的操作来更新它。
我不明白的是如何通过查找功能发生反向传播,这似乎是硬而不是软。这个操作的梯度是多少?它的输入ID之一?
嵌入矩阵查找在数学上等价于单热编码矩阵的点积(参见这个问题),这是一种平滑的线性运算。
例如,这是对 index 的查找3
:
这是渐变的公式:
...其中左侧是负对数似然的导数(即目标函数),x
是输入词,W
是嵌入矩阵,delta
是误差信号。
tf.nn.embedding_lookup
进行了优化,因此不会发生 one-hot 编码转换,但反向传播根据相同的公式工作。