0

我在使用 Flux / Zygote 时遇到了一个错误,它通过快速谷歌搜索似乎与“CUDA”/“Matrix Product State”相关。可以使用以下代码在 Julia 的最新稳定版本(编写时为 1.5.3)中复制它:

using Flux, Zygote

f(x,y) = sum(sum(batched_mul(x, y), dims=(1,2)));
g(x,y) = gradient((X_, Y_) -> f(X_, Y_), x, y)[1];
g(rand(1,1,1), rand(1,2,1))

# Result:
#  ** On entry to DGEMM  parameter number 10 had an illegal value
# 1×1×1 Array{Float64,3}:
# [:, :, 1] =
#  5.0e-324

我不确定引擎盖下发生了什么。我只能说“某事”出了问题,但它仍然会吐出看起来有效的结果。它还会在我的应用程序中向控制台发送垃圾邮件,以调整每个批次每个批次 1** On entry to DGEMM parameter number 10 had an illegal value个元素...使打印到控制台成为应用程序的瓶颈。

到底是怎么回事?我该如何解决?是否值得修复,还是我可以忽略它?如果忽略,我怎样才能阻止我的屏幕被这些警告/错误淹没?

编辑

我找到了一种使用OMEinsum的解决方法,它可以完成相同的工作,但不会向控制台发送错误消息:

@ein f(x,y)[i,k,l] := x[i,j,l] * y[j,k,l];
g(x,y) = gradient((X_, Y_) -> sum(f(X_, Y_)), x, y);
g(rand(2,2,2), rand(2,2,2))

# Result:
# ([0.9722633852326483 0.33601882819991724; 0.9722633852326483 0.33601882819991724]
# 
# [1.9523351912466416 1.0298638932648905; 1.9523351912466416 1.0298638932648905], [1.833517766235997 1.833517766235997; 1.2580222064250244 1.2580222064250244]
# 
# [0.4056799727986937 0.4056799727986937; 0.7334134199010598 0.7334134199010598])

不确定 OMEinsum 是否是这方面的“最佳”张量库(对张量来说是新手,可用库的数量令人眼花缭乱),但它对我来说是最容易理解的。

我没有将此作为答案,因为从根本上说,该错误batched_mul尚未修复。

4

0 回答 0