0

当使用 TensorFlow 1.x 和 TensorFlow hub 时,我们可以加载模块的规范来检查预期的输出形状(可能还有其他有用的规范!),如下所示:

spec = hub.load_module_spec("https://tfhub.dev/google/nnlm-en-dim128/1")
shape = spec.get_output_info_dict()['default'].get_shape()

当尝试对兼容 TF 2.0 的集线器模块执行相同操作时,我在调用时遇到以下错误消息load_module_spec

缺少支持的实现:loader(*('/tmp/tfhub_modules/82c4aaf4250ffb09088bd48368ee7fd00e5464fe',), **{})

是否有其他方法可以检查 TF 2.0 集线器模块的输出形状?

4

1 回答 1

1

对于 TensorFlow 2,TF Hub 将切换到提供 TF2 的原生基于对象的 SavedModels [ docRFC ]。tf.saved_model.load()如果它们已经在您的文件系统上,或者hub.load()从 URL 可选下载,则它们会被加载。这为您提供了一个恢复的Trackable对象,其__call__成员的行为类似于 a @tf.function,这意味着它具有一个或多个具体函数,每个函数都由 TF 图支持,并根据张量形状/dtypes 和非张量参数在它们之间分派。

在 TF2 的当前 alpha 版本中,如果您知道输入的允许 TensorSpec,您可以深入到输出,例如:

loaded_model = hub.load("https://tfhub.dev/google/tf2-preview/nnlm-en-dim128/1")
concrete_function = loaded_model.__call__.get_concrete_function(
    tf.TensorSpec((None,), tf.string))
print(concrete_function.output_shapes, ":",
      concrete_function.output_dtypes)

这给了我

(None, 128) : <dtype: 'float32'>
于 2019-05-24T14:57:17.307 回答