0

似乎tf.lookup.experimental.DenseHashTable不能保存向量,我找不到如何使用它的示例。

4

1 回答 1

1

您可以在下面找到 Tensorflow 中向量字典的简单实现。这也是使用tf.lookup.experimental.DenseHashTableand的一个例子tf.TensorArray

如前所述,向量不能保存在 中tf.lookup.experimental.DenseHashTable,因此tf.TensorArray用于保存实际向量。

当然,这是一个简单的例子,它不包括删除字典中的条目——这个操作需要对数组的空闲单元进行一些管理。此外,您应该阅读相应的 API 页面tf.lookup.experimental.DenseHashTable以及tf.TensorArray如何根据需要调整它们。

import tensorflow as tf


class DictionaryOfVectors:

  def __init__(self, dtype):
    empty_key = tf.constant('')
    deleted_key = tf.constant('deleted')

    self.ht = tf.lookup.experimental.DenseHashTable(key_dtype=tf.string,
                                                    value_dtype=tf.int32,
                                                    default_value=-1,
                                                    empty_key=empty_key,
                                                    deleted_key=deleted_key)
    self.ta = tf.TensorArray(dtype, size=0, dynamic_size=True, clear_after_read=False)
    self.inserts_counter = 0

  @tf.function
  def insertOrAssign(self, key, vec):
    # Insert the vector to the TensorArray. The write() method returns a new
    # TensorArray object with flow that ensures the write occurs. It should be 
    # used for subsequent operations.
    with tf.init_scope():
      self.ta = self.ta.write(self.inserts_counter, vec)

      # Insert the same counter value to the hash table
      self.ht.insert_or_assign(key, self.inserts_counter)
      self.inserts_counter += 1

  @tf.function
  def lookup(self, key):
    with tf.init_scope():
      index = self.ht.lookup(key)
      return self.ta.read(index)

dictionary_of_vectors = DictionaryOfVectors(dtype=tf.float32)
dictionary_of_vectors.insertOrAssign('first', [1,2,3,4,5])
print(dictionary_of_vectors.lookup('first'))

这个例子有点复杂,因为 insert 和 lookup 方法用 @tf.function. 因为方法改变了在它们之外定义的变量,所以tf.init_scope()使用了。您可能会问该方法发生了什么变化,lookup()因为它实际上只从哈希表和数组中读取。原因是在图形模式下,lookup()调用返回的索引是一个张量,而在 TensorArray 实现中,有一行包含以下if index < 0:内容失败:

OperatorNotAllowedInGraphError:不允许将 atf.Tensor用作 Python 。bool

tf.init_scope()正如其 API 文档中所解释的,当我们使用tf.function. 所以在那种情况下,索引不是张量而是标量。

于 2020-05-28T17:40:14.210 回答