6

我有一些句子,我正在为其创建一个嵌入,它非常适合相似性搜索,除非句子中有一些真正不寻常的单词。

在这种情况下,真正不寻常的词实际上包含句子中任何词的最相似信息,但由于该词显然不在模型的词汇表中,所有这些信息在嵌入过程中都丢失了。

我想获得 GUSE 嵌入模型已知的所有单词的列表,以便我可以将那些已知单词从我的句子中屏蔽掉,只留下“新颖”的单词。

然后我可以在我的目标语料库中对那些新词进行精确的词搜索,并为我的相似句子搜索实现可用性。

例如“我喜欢使用 Xapian!” 嵌入为“我喜欢使用 UNK”。

如果我只对“Xapian”进行关键字搜索而不是语义相似性搜索,我将获得比使用 GUSE 和向量 KNN 更多的相关结果。

关于如何提取 GUSE 已知/使用的词汇的任何想法?

4

2 回答 2

1

我假设你已经安装了 tensorflow 和 tensorflow_hub,并且你已经下载了模型。

重要提示:我假设您正在查看https://tfhub.dev/google/universal-sentence-encoder/4!不能保证不同版本的对象图看起来相同,很可能需要进行修改。

找到它在磁盘上的位置——/tmp/tfhub_modules除非你设置TFHUB_CACHE_DIR环境变量(Windows/Mac 有不同的位置),否则它就在某处。该路径应包含一个名为 的文件saved_model.pb,该文件是模型,使用协议缓冲区进行序列化。

不幸的是,字典是在模型的 Protocol Buffers 文件中序列化的,而不是作为外部资产,所以我们必须加载模型并从中获取变量。

策略是使用 tensorflow 的代码来反序列化文件,然后沿着序列化的对象树一直向下移动到字典。

import importlib

MODEL_PATH = 'path/to/model/dir' # e.g. '/tmp/tfhub_modules/063d866c06683311b44b4992fd46003be952409c/'

# Use the tensorflow internal Protobuf loader. A regular import statement will fail.
loader_impl = importlib.import_module('tensorflow.python.saved_model.loader_impl')

saved_model = loader_impl.parse_saved_model(MODEL_PATH)

# reach into the object graph to get the tensor
graph = saved_model.meta_graphs[0].graph_def
function = graph.library.function
node_type, node_value = function[5].node_def
# if you print(node_type) you'll see it's called "text_preprocessor/hash_table"
# as well as get insight into this branch of the object graph we're looking at
words_tensor = node_value.attr.get("value").tensor

word_list = [i.decode('utf-8') for i in words_tensor.string_val]
print(len(word_list)) # -> 400004

一些有帮助的资源:

  1. 与更改词汇表有关的GitHub 问题
  2. 从该问题链接的 Tensorflow Google-group 线程

额外说明

尽管 GitHub 问题可能会让您想到,这里的 400k 单词并不是GloVe 400k 词汇表。您可以通过下载GloVe 6B 嵌入(文件链接)、提取glove.6B.50d.txt、然后使用以下代码比较两个字典来验证这一点:

with open('/path/to/glove.6B.50d.txt') as f:
    glove_vocabulary = set(line.strip().split(maxsplit=1)[0] for line in f)

USE_vocabulary = set(word_list) # from above

print(len(USE_vocabulary - glove_vocabulary)) # -> 281150

检查不同的词汇本身就很有趣,例如,为什么 GloVe 有一个条目是“287.9”?

于 2020-11-10T19:39:39.593 回答
1

我将@Roee Shenberg 的较早答案和此处提供的解决方案结合起来提出了适用于 USE v4 的解决方案:

import importlib
loader_impl = importlib.import_module('tensorflow.python.saved_model.loader_impl')

saved_model = loader_impl.parse_saved_model("/tmp/tfhub_modules/063d866c06683311b44b4992fd46003be952409c/")
graph = saved_model.meta_graphs[0].graph_def

fns = [f for f in saved_model.meta_graphs[0].graph_def.library.function if "ptb" in str(f).lower()];
print(len(fns)) # should be 1

nodes_with_sp = [n for n in fns[0].node_def if n.name == "Embeddings_words"]
print(len(nodes_with_sp)) # should be 1

words_tensor = nodes_with_sp[0].attr.get("value").tensor

word_list = [i.decode('utf-8') for i in words_tensor.string_val]
print(len(word_list)) # should be 400004

如果您只是对我在这里上传的文字感到好奇

于 2021-10-21T17:06:16.690 回答