1

我想在 Java 中对大型稀疏矩阵进行 Cholesky 分解。目前我正在使用Parallel Colt库类SparseDoubleCholeskyDecomposition但它比使用我在 C 中编写的密集矩阵代码要慢得多,我在 java 中使用 JNI

例如,对于 5570x5570 的非零密度为 0.25% 的矩阵,使用 SparseDoubleCholeskyDecomposition 需要26.6 秒来分解,而我自己的使用密集存储的相同矩阵的代码只需要1.12 秒。但是,如果我将密度设置为 0.025%,那么 colt 库只需要 0.13 秒。

我还使用压缩行存储和 OpenMP 在 C 中编写了我自己的稀疏矩阵 Cholesky 分解,它也比 SparseDoubleCholeskyDecomposition 快很多,但仍然比我使用密集存储的算法慢。我可能可以进一步优化它,但无论如何 Cholesky 因子都很密集,因此它既不能解决速度问题,也不能解决存储问题。

但是我想要毫秒级的时间,我最终需要扩展到超过 10000x10000 的矩阵,所以即使是我自己的密集矩阵的密集代码也会变得太慢并且使用太多内存。

我有理由相信稀疏矩阵的 Cholesky 分解可以更快地完成并且使用更少的内存。也许我需要一个更好的 Java BLAS 库?是否有 Java 库可以有效地对稀疏矩阵进行 Cholesky 分解?也许我没有以最佳方式使用 Parallel Colt?

import cern.colt.matrix.tdouble.algo.decomposition.SparseDoubleCholeskyDecomposition;
import cern.colt.matrix.tdouble.impl.*;
import java.util.Random;

public class Cholesky {
    static SparseRCDoubleMatrix2D create_sparse(double[] a, int n, double r) {
        SparseRCDoubleMatrix2D result = new SparseRCDoubleMatrix2D(n, n);
        Random rand = new Random();
        for(int i=0; i<n; i++) {
            for(int j=i; j<n; j++) {
                double c = rand.nextDouble();
                double element = c<r ? rand.nextDouble() : 0;
                a[i*n+j] = element;
                a[j*n+i] = element;
                if (element != 0) {
                    result.setQuick(i, j, element);
                    result.setQuick(j, i, element);
                }
            }
        }
        for(int i=0; i<n; i++) {
            a[i * n + i] += n;
            result.setQuick(i,i, result.getQuick(i,i) + n);
        }
        return result;
    }

    public static void main(String[] args) {
        int n = 5570;
        //int n = 2048;
        double[] a = new double[n*n];
        SparseRCDoubleMatrix2D sparseMatrix = create_sparse(a, n, 0.0025);

        long startTime, endTime, duration;

        startTime = System.nanoTime();
        SparseDoubleCholeskyDecomposition sparseDoubleCholeskyDecomposition = new SparseDoubleCholeskyDecomposition(sparseMatrix, 0);
        endTime = System.nanoTime();
        duration = (endTime - startTime);
        System.out.printf("colt time construct %.2f s\n", 1E-9*duration);
    }
}
4

0 回答 0