问题标签 [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
1 回答
181 浏览

tensorflow - 将 npz jax 权重转换为 keras h5 权重

有没有办法将 jax npz 预训练的权重转换为 kers/tf.keras h5 格式的权重?

在网上找不到任何东西。

谢谢

0 投票
3 回答
230 浏览

python - 创建一个零的 3D 张量,在 numpy/jax 中的每个切片上随机放置一个“1”

例如,我需要创建一个像这样 (5,3,2) 的 3D 张量

每个切片中都应该随机放置一个“一”(如果您认为张量是一条面包)。这可以使用循环来完成,但我想对这部分进行矢量化。

0 投票
1 回答
218 浏览

tensorflow - jax 库中的“xla_client”到底是什么?

如果您阅读jax 源代码,您会遇到一个名为xla_client. 经常这样导入

这意味着这xla_client是一个 python 模块,但我找不到任何具有该名称的文件或对该名称的变量的引用。

我假设它与https://pypi.org/project/jaxlib/相关,但这个包只是链接回 jax 源代码。

任何人都可以提示我吗?

0 投票
1 回答
112 浏览

jax - 将一段代码与 jax 跟踪隔离开来

提前为这个问题的含糊程度表示歉意(不幸的是,我对 jax 跟踪的工作原理知之甚少,无法更准确地表述它),但是:有没有办法将函数或代码块与 jax 跟踪完全隔离?

对于上下文,我具有以下形式的功能:

本质上,我想在进行任何 jax 转换时调用g(x, z), 并将其视为常量。z但是,设置参数z非常尴尬,因此使用辅助函数h将更易于指定的输入y转换为g. 我希望 jax 将h其视为不可追踪的黑匣子,因此jit(lambda x: f(x, y0))对特定y0对象的操作与首先使用 计算z0 = h(y0)numpy然后执行jit(lambda x: g(x, z0))(以及与grad或任何其他函数转换类似)相同。

在我的代码中,我已经编写h了只使用标准numpy(我认为这可能会导致黑盒行为),但是 的编译时间jit(lambda x: f(x, y0))明显长于jit(lambda x: g(x, z0))for的编译时间z0 = h(y0)。我有一种感觉,编译时间可能与 jax 跟踪中的许多循环有关h,但我不确定。

一些附加说明:

  • 以一种对 jax 友好的方式编写h会很尴尬(输入格式参差不齐,大量循环/条件,输出形状取决于输入值等)并且最终比它的价值更麻烦,因为该函数执行起来非常便宜,我不知道永远不需要区分它(输入数据是基于整数的)。

想法?

为清楚起见编辑添加:我知道如果例如f是顶级功能,则可能有解决方法。在这种情况下,让用户首先调用h以“预编译”对 jax 友好的输入g,然后自由地执行他们想要的任何 jax 转换,这并不是什么大问题lambda x: g(x, z0)。但是,我在想象这样的情况,我们有许多想要链接在一起的函数,它们具有相同的结构f,其中有一些对 jax 不友好的输入/计算,但这些输入将始终被视为 jax计算的一部分。原则上,人们总是可以提取这些预先计算来设置 jax 的东西,但是如果我们有一个这种类型的函数的非平凡集合,它们会相互调用,这似乎很困难。

是否有某种方法可以控制如何f跟踪,以便在跟踪时知道只评估z=h(y)(而不是跟踪h)然后继续跟踪g(x, z)

0 投票
1 回答
473 浏览

python - 了解其梯度函数中的 JAX argnums 参数

我试图了解argnumsJAX 渐变函数中的行为。假设我有以下功能:

我正在通过以下方式获取渐变:

argnums= (0,1)在这种情况下,但这意味着什么?关于哪些变量计算梯度?如果我改用它会有什么区别argnums=0?另外,我可以使用相同的函数来获取 Hessian 矩阵吗?

我查看了有关它的JAX 帮助部分,但无法弄清楚

0 投票
1 回答
56 浏览

tensorflow - 重新实现 bert-style pooler 会引发形状错误,就好像仍然需要长度维度一样

我已经训练了一个现成的 Transformer()。

现在我想使用编码器来构建分类器。为此,我只想使用第一个令牌的输出(bert-style cls-token-result)并通过密集层运行它。

我所做的:

形状: 编码 器 给我形状(64、50、512)64 = batch_size, 50 = seq_len, 512 = model_dim

pooler 为我提供了符合预期和期望的形状(64, 512)。

密集层应该为每个批次成员采用 512 个维度并分类超过 7 个类。但我猜 trax/jax 仍然希望它的长度为 seq_len (50)。

我想念什么?

完整追溯:

0 投票
1 回答
325 浏览

python - JAX vmap 行为

我试图了解 JAX vmap 的行为,所以我编写了以下代码:

输出是:

我知道唯一被改变的输入是b,但是有人可以解释为什么会这样吗?在我对函数进行矢量化后,点积的行为如何?

0 投票
1 回答
43 浏览

gpu - 分析 JAX 代码:什么是 redzone_checker,为什么要花这么多时间?

我找到了这篇文章,但仍不清楚redzone_checker内核在做什么以及为什么。具体来说,它是否应该占用我应用程序运行时间的 90% 以上?TensorBoard 报告说它占用了我的 JAX 代码的绝大部分运行时间,我想知道

  1. 实际上是这个内核花费了太多时间,还是这是使用 TensorBoard 分析 JAX 的副作用(即,输出在某种程度上具有误导性)?
  2. 有没有办法减少redzone_checker内核花费的时间?这甚至是个好主意吗?

提前感谢您的任何见解。

0 投票
0 回答
1129 浏览

python-3.x - RuntimeError:未实现:未找到 DNN 库

我试图通过谷歌实现视觉转换器,我在推理过程中遇到了这个错误:

我关注了这篇文章,但没有帮助。我应该怎么办?

0 投票
1 回答
117 浏览

python - 如何在 jit 编译的 jax 代码中执行非整数索引运算?

如果我们对数组索引执行非整数计算(然后转换为 int() ),似乎我们仍然无法将结果用作 jit 编译的 jax 代码中的有效索引。我们如何解决这个问题?

以下是一个最小的示例。具体问题:命令 jnp.diag_indices(d) 是否可以在不向 fun() 传递额外参数的情况下工作

在木星单元中运行它: