问题标签 [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
1 回答
306 浏览

python - jax vmap:强制执行正确的形状

我正在使用vmap对部分代码进行矢量化处理。这是一个最小的例子,在矢量化之前:

使用 vmap:

但这可能会出错,在矢量化之后:

这里发生的是samples[0]具有形状的(2,)。矢量化函数调用沿第一个轴拆分其输入参数,因此输入 2 个 shape 数组(1,)。由于使用 广播a,结果输出再次具有形状(2,)并堆叠到(2,2)数组中。

这对我来说似乎很危险。代码看起来很正常,生成的输出很容易被其他一些隐藏其损坏形状的广播规则所消耗。

是否可以强制执行正确的形状?

0 投票
1 回答
705 浏览

python - pytorch 可以优化顺序操作(如张量流图或 JAX 的 jit)吗?

最初,tensorflow 和 pytorch 有一个根本的区别:

  • tensorflow 基于计算图。构建此图并在会话中对其进行评估是两个独立的步骤。当它被使用时,图表不会改变,这允许优化。
  • torch 急切地评估张量上的操作。这使得 API 更方便(无会话),但也失去了识别和优化总是按顺序发生的操作的潜力。

现在,这种差异变得不那么明显了。Tensorflow 通过tf eager回应了火炬的流行。还有一个JAX项目,它建立在与 tensorflow ( XLA )相同的底层框架上。JAX 没有会话的概念。但它允许您通过简单地调用jit将多个操作一起编译。

由于 Tensorflow 已经开始涵盖 PyTorch 功能,PyTorch 是否也在努力整合 Tensorflow 的优势?PyTorch(或其路线图)中是否有类似会话或 jit 功能的东西?

API 文档有一个jit 部分,但据我所知,这更多是关于导出模型。

0 投票
1 回答
754 浏览

python - JAX:jit 函数的时间随着函数访问的内存而超线性增长

这是一个简单的例子,它对两个高斯 pdf 的乘积进行数值积分。其中一个高斯是固定的,均值始终为 0。另一个高斯的均值不同:

该函数看起来很简单,但它根本无法扩展。以下列表包含成对的(integr_resolution 的值,运行代码所用的时间):

  • 100 | 0.107s
  • 200 | 0.23s
  • 400 | 0.537s
  • 800 | 1.52s
  • 1600 | 5.2s
  • 3200 | 19 岁
  • 6400 | 134s

作为参考,应用到的 unjitted 函数integr_resolution=6400需要 0.02s。

我认为这可能与函数正在访问全局变量这一事实有关。但是移动代码以在函数内部设置积分点对时序没有显着影响。以下代码需要 5.36 秒才能运行。它对应于先前花费 5.2 秒的 1600 表条目:

这里发生了什么?

0 投票
1 回答
1594 浏览

python - 使用 Python JAX/Autograd 的向量值函数的雅可比行列式

我有一个将向量映射到向量的函数

f: R^n -> R^n

我想计算它的雅可比行列式

det J = |df/dx|,

其中雅可比定义为

J_ij = |df_i/dx_j|.

因为我可以使用numpy.linalg.det, 来计算行列式,所以我只需要雅可比矩阵。我知道numdifftools.Jacobian,但这使用数值微分,我在自动微分之后。输入Autograd/ JAX(我暂时坚持Autograd,它有一个autograd.jacobian()方法,但我很乐意使用JAX,只要我得到我想要的)。如何autograd.jacobian()正确使用这个函数和向量值函数?

作为一个简单的例子,我们来看看函数

![f(x)=(x_0^2, x_1^2)]( https://chart.googleapis.com/chart?cht=tx&chl=f(x%29%20%3D%20(x_0%5E2% 2C%20x_1%5E2%29 )

具有雅可比行列式

![J_f = diag(2 x_0, 2 x_1)]( https://chart.googleapis.com/chart?cht=tx&chl=J_f%20%3D%20%5Cooperatorname%7Bdiag%7D(2x_0%2C%202x_1% 29 )

导致雅可比行列式

det J_f = 4 x_0 x_1

第一种方法给了我正确的值,但形状错误。为什么会.jacobian()返回这样一个嵌套数组?如果我正确地重塑它,我会得到正确的结果:

但是现在让我们看看这在数组广播中是如何工作的,当我尝试评估多个值的雅​​可比行列式时x

显然,这两种形状都是错误的,正确的(如我想要的雅可比矩阵)woule be

shape=(6,2,2)

我需要如何使用autograd.jacobian(或jax.jacfwd/ jax.jacrev)才能使其正确处理多个向量输入?


注意:使用显式循环并手动处理每个点,我得到了正确的结果。但是有没有办法做到这一点?

0 投票
2 回答
1132 浏览

python - Jax 的 JIT 和 Numpy 限制问题

我最近开始尝试有趣的 python 库Jax,其中包含增强的 Numpy 和自动微分器。我想尝试创建的是一个粗略的“可微渲染器”,通过在 python 中编写着色器和损失函数,然后使用 Jax 的 AD 来查找渐变。然后我们应该能够通过在这个损失梯度上运行梯度下降来逆向渲染图像。我用简单的着色器让它工作得很好,但是当我使用布尔表达式时我遇到了问题。这是我的着色器的代码,它生成一个棋盘图案:

这是我的渲染函数,这是 JIT 失败的第一个地方:

我收到的错误信息是:

我猜这是因为您不能在具有布尔值的函数上运行 JIT,其值取决于编译时未确定的内容。但是我怎样才能重写它以使用 JIT 呢?如果没有 JIT,它会非常缓慢。

我的另一个问题是,我能做些什么来加速 Jax 的 Numpy?使用普通 Numpy 渲染我的图像(100x100 像素)需要几毫秒,但使用 Jax 的 Numpy 需要几秒钟!感谢:D

0 投票
1 回答
903 浏览

tensorflow - 简单来说,JAX、Trax 和 TensorRT 有什么区别?

我一直在使用 TensorRT 和 TensorFlow-TRT 来加速我的 DL 算法的推理。

然后我听说了:

两者似乎都加速了深度学习。但我很难理解他们。任何人都可以简单地解释它们吗?

0 投票
2 回答
128 浏览

numpy - 用 jax 计算逐行(或逐轴)点积的最佳方法是什么?

我有两个形状(N,M)的数值数组。我想计算一个逐行的点积。即产生一个形状为 (N,) 的数组,使得第 n 行是每个数组的第 n 行的点积。

我知道numpy的inner1d方法。使用 jax 执行此操作的最佳方法是什么?jax 有jax.numpy.inner,但这有别的作用。

0 投票
1 回答
412 浏览

python - 查找函数的梯度:Sympy vs. Jax

我有一个Black_Cox()调用其他函数的函数,如下所示:

我需要使用Black_Cox函数 wrt的导数V。更准确地说,我需要在改变其他参数的数千条路径中评估这个导数,找到导数并在 some 处进行评估V

最好的方法是什么?

  • 我是否应该像在 Mathematica 中那样使用sympy来找到这个导数,然后根据我V的选择进行评估:D[BlackCox[V, 10, 100, 160], V] /. V -> 180,或者

  • 我应该使用jax吗?

如果sympy,你会如何建议我这样做?

jax我了解,我需要执行以下导入:

并在获得渐变之前重新评估我的功能:

如果我仍然需要使用numpy函数的版本,我是否必须为每个函数创建 2 个实例,或者是否有一种优雅的方式来复制函数jax

0 投票
1 回答
1223 浏览

python - vmap 在 jax 中的列表上

使用 jax,我尝试计算每个样本的梯度,处理它们,然后将它们带入正常形式以计算正常参数更新。我的工作代码看起来像

gradients的形式在哪里list(tuple(DeviceArray(...), DeviceArray(...)), ...)

现在我尝试将循环重写为 vmap (不确定它是否最终会带来加速)

sum_samples只调用一次,而不是为列表中的每个元素调用。

列表是问题还是我理解其他错误?

0 投票
1 回答
335 浏览

python - Google JAX 1D 卷积神经网络

我正在尝试使用stax.GeneralConv() ( https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv ) 在Google Jax中实现一维卷积神经网络。我有一个包含 18 个条目的一维输入数组和一个包含 6 个条目的输出数组。我想实现一个内核宽度为 3 的 CNN,如下所示:

具有初始网络参数:

但我收到以下错误:

stax 要求维度编号rhs_spec至少为 2 个字符长,但我使用一维过滤器。有人知道如何解决这个问题吗?