0

有没有办法将 jax npz 预训练的权重转换为 kers/tf.keras h5 格式的权重?

在网上找不到任何东西。

谢谢

4

1 回答 1

0

npz从格式转换为格式的最直接方法是将h5数据加载到内存中,然后重写。

这是一个简短的例子:

import jax.numpy as jnp
from jax import random
import h5py

# Create some random weights
key = random.PRNGKey(1701)
weights = random.normal(key, shape=(100,))

# Save to an npz file
jnp.savez('weights.npz', weights=weights)

# Load the npz and convert to h5
data = jnp.load('weights.npz')
with h5py.File('weights.h5', 'w') as hf:
    hf.create_dataset('weights', data=data['weights'])

请注意,这将取决于 npz 文件的内容和生成的 h5 文件的所需结构。

于 2021-02-16T05:45:53.167 回答