1

对于我的批次中的某些矩阵,由于矩阵是奇异的,我遇到了异常。

L = th.cholesky(Xt.bmm(X))

cholesky_cpu:对于批次 51100:U(22,22) 为零,单数 U

由于它们对于我的用例来说很少,我想忽略异常并进一步处理它们。我将结果计算设置为nan有可能以某种方式吗?

实际上,如果我仍然catch使用异常并使用continue它并没有完成其余批次的计算。

在带有 Pytorch libtorch 的 C++ 中也会发生同样的情况。

4

3 回答 3

1

在执行 cholesky 分解时,PyTorch 依赖于 CPU 张量的 LAPACK 和 CUDA 张量的 MAGMA。在用于调用 LAPACK 的 PyTorch 代码中,批处理只是迭代,zpotrs_分别在每个矩阵上调用 LAPACK 的函数。在用于调用 MAGMA 的 PyTorch 代码中,整个批次都是使用 MAGMA 处理的magma_dpotrs_batched,这可能比分别迭代每个矩阵要快。

AFAIK 没有办法指示 MAGMA 或 LAPACK 不引发异常(但公平地说,我不是这些软件包的专家)。由于 MAGMA 可能会以某种方式利用批处理,我们可能不想只默认使用迭代方法,因为我们可能会因不执行批处理的 cholesky 而失去性能。

一种可能的解决方案是首先尝试执行批量 Cholesky 分解,如果失败,那么我们可以对批处理中的每个元素执行 Cholesky 分解,将失败的条目设置为 NaN。

def cholesky_no_except(x, upper=False, force_iterative=False):
    success = False
    if not force_iterative:
        try:
            results = torch.cholesky(x, upper=upper)
            success = True
        except RuntimeError:
            pass

    if not success:
        # fall back to operating on each element separately
        results_list = []
        x_batched = x.reshape(-1, x.shape[-2], x.shape[-1])
        for batch_idx in range(x_batched.shape[0]):
            try:
                result = torch.cholesky(x_batched[batch_idx, :, :], upper=upper)
            except RuntimeError:
                # may want to only accept certain RuntimeErrors add a check here if that's the case
                # on failure create a "nan" matrix
                result = float('nan') + torch.empty(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
            results_list.append(result)
        results = torch.cat(results_list, dim=0).reshape(*x.shape)

    return results

如果您希望在 cholesky 分解期间异常常见,您可能希望force_iterative=True跳过尝试使用批处理版本的初始调用,因为在这种情况下,此函数可能只是在第一次尝试时浪费时间。

于 2020-02-15T02:05:12.380 回答
1

根据Pytorch Discuss论坛无法捕获异常。

不幸的是,解决方案是实现我自己的 简单批处理 Cholesky ( th.cholesky(..., upper=False)),然后使用th.isnan.

import torch as th

# nograd cholesky
def cholesky(A):
    L = th.zeros_like(A)

    for i in range(A.shape[-1]):
        for j in range(i+1):
            s = 0.0
            for k in range(j):
                s = s + L[...,i,k] * L[...,j,k]

            L[...,i,j] = th.sqrt(A[...,i,i] - s) if (i == j) else \
                      (1.0 / L[...,j,j] * (A[...,i,j] - s))
    return L
于 2020-04-24T11:51:24.590 回答
1

我不知道这与发布的其他解决方案的速度相比如何,但它可能会更快。

首先用于torch.det确定您的批次中是否有任何奇异矩阵。然后屏蔽掉这些矩阵。

output = Xt.bmm(X)
dets = torch.det(output)

# if output is of shape (bs, x, y), dets will be of shape (bs)
bad_idxs = dets==0 #might want an allclose here

output[bad_idxs] = 1. # fill singular matrices with 1s

L = torch.cholesky(output)

在您可能需要处理用 1 填充的奇异矩阵之后,但您有它们的索引值,因此很容易抓住它们或排除它们。

于 2020-05-10T05:43:58.067 回答