0

我在 Python 中创建了一个tf.estimator模型和一个管道,并将其以 TF 2.1 的格式保存。由于 tfjs-node 不支持 int64 或 float64 类型,因此无法加载模型。tf.datatf.saved_model

在 Tensorboard 上,我观察到一些输入管道 Python 变量被自动声明为 64 位类型。

在此处输入图像描述

例如,batch_sizeepochs以上。如何避免此问题并tf.estimator在 tfjs-node 中加载模型而不进行转换?

为了重现,

4

1 回答 1

1

由于 tfjs-node 不支持 INT64 和 FLOAT64,如果输入/输出 tensor dtypes 为 INT64/FLOAT64,则无法直接加载和执行模型。一种解决方法是使用张量转换函数包装模型:

  1. 在python中,加载原始模型,并创建一个输入张量为INT32或FLOAT32的新模型
  2. 在新模型中将输入张量转换为 DTYPE INT64/FLOAT64
  3. 将转换后的张量提供给原始模型,并返回模型输出(如果需要,转换值 dtype)
  4. 将此新模型导出为 TF SavedModel

因此,通过这种包装,tfjs-node 支持模型的输入/输出 dtype,尽管它可能会失去一些准确性。

于 2020-03-26T20:45:36.807 回答