对于我的批次中的某些矩阵,由于矩阵是奇异的,我遇到了异常。
L = th.cholesky(Xt.bmm(X))
cholesky_cpu:对于批次 51100:U(22,22) 为零,单数 U
由于它们对于我的用例来说很少,我想忽略异常并进一步处理它们。我将结果计算设置为nan有可能以某种方式吗?
实际上,如果我仍然catch
使用异常并使用continue
它并没有完成其余批次的计算。
在带有 Pytorch libtorch 的 C++ 中也会发生同样的情况。
在执行 cholesky 分解时,PyTorch 依赖于 CPU 张量的 LAPACK 和 CUDA 张量的 MAGMA。在用于调用 LAPACK 的 PyTorch 代码中,批处理只是迭代,zpotrs_
分别在每个矩阵上调用 LAPACK 的函数。在用于调用 MAGMA 的 PyTorch 代码中,整个批次都是使用 MAGMA 处理的magma_dpotrs_batched
,这可能比分别迭代每个矩阵要快。
AFAIK 没有办法指示 MAGMA 或 LAPACK 不引发异常(但公平地说,我不是这些软件包的专家)。由于 MAGMA 可能会以某种方式利用批处理,我们可能不想只默认使用迭代方法,因为我们可能会因不执行批处理的 cholesky 而失去性能。
一种可能的解决方案是首先尝试执行批量 Cholesky 分解,如果失败,那么我们可以对批处理中的每个元素执行 Cholesky 分解,将失败的条目设置为 NaN。
def cholesky_no_except(x, upper=False, force_iterative=False):
success = False
if not force_iterative:
try:
results = torch.cholesky(x, upper=upper)
success = True
except RuntimeError:
pass
if not success:
# fall back to operating on each element separately
results_list = []
x_batched = x.reshape(-1, x.shape[-2], x.shape[-1])
for batch_idx in range(x_batched.shape[0]):
try:
result = torch.cholesky(x_batched[batch_idx, :, :], upper=upper)
except RuntimeError:
# may want to only accept certain RuntimeErrors add a check here if that's the case
# on failure create a "nan" matrix
result = float('nan') + torch.empty(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
results_list.append(result)
results = torch.cat(results_list, dim=0).reshape(*x.shape)
return results
如果您希望在 cholesky 分解期间异常常见,您可能希望force_iterative=True
跳过尝试使用批处理版本的初始调用,因为在这种情况下,此函数可能只是在第一次尝试时浪费时间。
根据Pytorch Discuss
论坛无法捕获异常。
不幸的是,解决方案是实现我自己的 简单批处理 Cholesky ( th.cholesky(..., upper=False)
),然后使用th.isnan
.
import torch as th
# nograd cholesky
def cholesky(A):
L = th.zeros_like(A)
for i in range(A.shape[-1]):
for j in range(i+1):
s = 0.0
for k in range(j):
s = s + L[...,i,k] * L[...,j,k]
L[...,i,j] = th.sqrt(A[...,i,i] - s) if (i == j) else \
(1.0 / L[...,j,j] * (A[...,i,j] - s))
return L
我不知道这与发布的其他解决方案的速度相比如何,但它可能会更快。
首先用于torch.det
确定您的批次中是否有任何奇异矩阵。然后屏蔽掉这些矩阵。
output = Xt.bmm(X)
dets = torch.det(output)
# if output is of shape (bs, x, y), dets will be of shape (bs)
bad_idxs = dets==0 #might want an allclose here
output[bad_idxs] = 1. # fill singular matrices with 1s
L = torch.cholesky(output)
在您可能需要处理用 1 填充的奇异矩阵之后,但您有它们的索引值,因此很容易抓住它们或排除它们。