问题标签 [jax]
For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.
tensorflow - 将 npz jax 权重转换为 keras h5 权重
有没有办法将 jax npz 预训练的权重转换为 kers/tf.keras h5 格式的权重?
在网上找不到任何东西。
谢谢
python - 创建一个零的 3D 张量,在 numpy/jax 中的每个切片上随机放置一个“1”
例如,我需要创建一个像这样 (5,3,2) 的 3D 张量
每个切片中都应该随机放置一个“一”(如果您认为张量是一条面包)。这可以使用循环来完成,但我想对这部分进行矢量化。
tensorflow - jax 库中的“xla_client”到底是什么?
如果您阅读jax 源代码,您会遇到一个名为xla_client
. 经常这样导入
这意味着这xla_client
是一个 python 模块,但我找不到任何具有该名称的文件或对该名称的变量的引用。
我假设它与https://pypi.org/project/jaxlib/相关,但这个包只是链接回 jax 源代码。
任何人都可以提示我吗?
jax - 将一段代码与 jax 跟踪隔离开来
提前为这个问题的含糊程度表示歉意(不幸的是,我对 jax 跟踪的工作原理知之甚少,无法更准确地表述它),但是:有没有办法将函数或代码块与 jax 跟踪完全隔离?
对于上下文,我具有以下形式的功能:
本质上,我想在进行任何 jax 转换时调用g(x, z)
, 并将其视为常量。z
但是,设置参数z
非常尴尬,因此使用辅助函数h
将更易于指定的输入y
转换为g
. 我希望 jax 将h
其视为不可追踪的黑匣子,因此jit(lambda x: f(x, y0))
对特定y0
对象的操作与首先使用 计算z0 = h(y0)
,numpy
然后执行jit(lambda x: g(x, z0))
(以及与grad
或任何其他函数转换类似)相同。
在我的代码中,我已经编写h
了只使用标准numpy
(我认为这可能会导致黑盒行为),但是 的编译时间jit(lambda x: f(x, y0))
明显长于jit(lambda x: g(x, z0))
for的编译时间z0 = h(y0)
。我有一种感觉,编译时间可能与 jax 跟踪中的许多循环有关h
,但我不确定。
一些附加说明:
- 以一种对 jax 友好的方式编写
h
会很尴尬(输入格式参差不齐,大量循环/条件,输出形状取决于输入值等)并且最终比它的价值更麻烦,因为该函数执行起来非常便宜,我不知道永远不需要区分它(输入数据是基于整数的)。
想法?
为清楚起见编辑添加:我知道如果例如f
是顶级功能,则可能有解决方法。在这种情况下,让用户首先调用h
以“预编译”对 jax 友好的输入g
,然后自由地执行他们想要的任何 jax 转换,这并不是什么大问题lambda x: g(x, z0)
。但是,我在想象这样的情况,我们有许多想要链接在一起的函数,它们具有相同的结构f
,其中有一些对 jax 不友好的输入/计算,但这些输入将始终被视为 jax计算的一部分。原则上,人们总是可以提取这些预先计算来设置 jax 的东西,但是如果我们有一个这种类型的函数的非平凡集合,它们会相互调用,这似乎很困难。
是否有某种方法可以控制如何f
跟踪,以便在跟踪时知道只评估z=h(y)
(而不是跟踪h
)然后继续跟踪g(x, z)
?
python - 了解其梯度函数中的 JAX argnums 参数
我试图了解argnums
JAX 渐变函数中的行为。假设我有以下功能:
我正在通过以下方式获取渐变:
argnums= (0,1)
在这种情况下,但这意味着什么?关于哪些变量计算梯度?如果我改用它会有什么区别argnums=0
?另外,我可以使用相同的函数来获取 Hessian 矩阵吗?
我查看了有关它的JAX 帮助部分,但无法弄清楚
tensorflow - 重新实现 bert-style pooler 会引发形状错误,就好像仍然需要长度维度一样
我已经训练了一个现成的 Transformer()。
现在我想使用编码器来构建分类器。为此,我只想使用第一个令牌的输出(bert-style cls-token-result)并通过密集层运行它。
我所做的:
形状:
编码
器
给我形状(64、50、512)64 = batch_size,
50 = seq_len,
512 = model_dim
pooler 为我提供了符合预期和期望的形状(64, 512)。
密集层应该为每个批次成员采用 512 个维度并分类超过 7 个类。但我猜 trax/jax 仍然希望它的长度为 seq_len (50)。
我想念什么?
完整追溯:
python - JAX vmap 行为
我试图了解 JAX vmap 的行为,所以我编写了以下代码:
输出是:
我知道唯一被改变的输入是b
,但是有人可以解释为什么会这样吗?在我对函数进行矢量化后,点积的行为如何?
gpu - 分析 JAX 代码:什么是 redzone_checker,为什么要花这么多时间?
我找到了这篇文章,但仍不清楚redzone_checker
内核在做什么以及为什么。具体来说,它是否应该占用我应用程序运行时间的 90% 以上?TensorBoard 报告说它占用了我的 JAX 代码的绝大部分运行时间,我想知道
- 实际上是这个内核花费了太多时间,还是这是使用 TensorBoard 分析 JAX 的副作用(即,输出在某种程度上具有误导性)?
- 有没有办法减少
redzone_checker
内核花费的时间?这甚至是个好主意吗?
提前感谢您的任何见解。
python-3.x - RuntimeError:未实现:未找到 DNN 库
我试图通过谷歌实现视觉转换器,我在推理过程中遇到了这个错误:
我关注了这篇文章,但没有帮助。我应该怎么办?
python - 如何在 jit 编译的 jax 代码中执行非整数索引运算?
如果我们对数组索引执行非整数计算(然后转换为 int() ),似乎我们仍然无法将结果用作 jit 编译的 jax 代码中的有效索引。我们如何解决这个问题?
以下是一个最小的示例。具体问题:命令 jnp.diag_indices(d) 是否可以在不向 fun() 传递额外参数的情况下工作
在木星单元中运行它: