0

我有一个包含 1000 个项目的数据集。在针对数据训练模型之前,我对数据进行了标准化。

我现在想使用该模型进行预测。但是,据我了解,我需要对将提供给我需要预测的模型的输入进行标准化。为了做到这一点,我需要在训练时计算出平均值和标准差。

虽然我可以将它打印到控制台,但如何“保存”它 - 以供以后使用?我试图在这里了解如何保存在训练数据标准化时使用的均值和标准差的过程 - 这样我就可以在进行预测时再次使用它。

4

1 回答 1

0

我确定我们可以首先通过以下方式获得张量的数组表示:

// tensor here is the tensor variable that contains the tensor
const tensorAsArray = tensor.arraySync()

然后,我们像任何其他字符串一样将其保存到文件中

fs.writeFile(myFilePath, JSON.stringify(tensorAsArray), 'utf-8')

要将其读回并将其用作张量,我们将做相反的事情:

const tensorAsArray = JSON.parse(fs.readFile(myFilePath, 'utf-8'))
const tensor = tf.tensor(tensorAsArray)

这让我可以保存平均值和标准以供以后使用。

于 2021-06-14T17:25:55.237 回答