4

我正在使用 ND4j(目前版本 1.0.0-beta5)开发一个严重依赖于 Java 数组操作的科学应用程序。在我的整个管道中,我需要动态选择 [2,195102] 矩阵的非连续子集(更准确地说是几十/几百列)。知道如何在这个框架中实现这一点吗?

简而言之,我正在尝试实现这个 python/numpy 操作:

import numpy as np
arrayData = np.array([[1, 5, 0, 6, 2, 0, 9, 0, 5, 2],
       [3, 6, 1, 0, 4, 3, 1, 4, 8, 1]])
arrayIndex = np.array((1,5,6))
res  = arrayData[:, arrayIndex]
# res value is
# array([[5, 0, 9],
#        [6, 3, 1]])

到目前为止,我设法使用NDArray.getColumns函数选择了所需的列(以及 indexArray 中的 NDArray.data().asInt() 以提供索引的值)。问题是文档明确指出,关于在计算期间检索信息,“请注意,这不应该用于速度”(请参阅​​ NDArray.ToIntMatrix()的文档以查看完整消息 - 不同的方法,相同的操作)。

我查看了NDArray.get()的不同原型,似乎没有一个符合要求。我想NDArray.getWhere()可能会起作用——如果它,正如我所假设的那样,只返回满足条件的元素——但到目前为止,它没有成功使用它。在解释所需的参数/用法时,文档相对较少。

谢谢大家的时间和帮助:)

编辑(2019 年 4 月 11 日):关于我尝试过的一些精确度。我玩弄了 NDArray.get() 并使用了索引:

INDArray arrayData = Nd4j.create(new int[]
                    {1, 5, 0, 6, 2, 0, 9, 0, 5, 2,
                     3, 6, 1, 0, 4, 3, 1, 4, 8, 1},   new long[]{2, 10}, DataType.INT);
INDArray arrayIndex = Nd4j.create(new int[]{1, 5, 6}, new long[]{1,  3}, DataType.INT);

INDArray colSelection = null;

//index free version
colSelection = arrayData.getColumns(arrayIndex.toIntVector());
/*
* colSelection value is
* [[5, 0, 9],
*  [6, 3, 1]]
* but the toIntVector() call pulls the data from the back-end storage
* and re-inject them. That is presumed to be slow.
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 4001 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 5339 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 7016 ms for 100000 iterations
*/

//index version
colSelection = arrayData.get(NDArrayIndex.all(), NDArrayIndex.indices(arrayIndex.toLongVector()));
/*
* Same result, but same problem regarding toLongVector() this time around.
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 3200 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 4269 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 5252 ms for 100000 iterations
*/

//weird but functional version (that I just discovered)
colSelection = arrayData.transpose().get(arrayIndex); // the transpose operation is necessary to not hit an IllegalArgumentException: Illegal slice 5
// note that transposing the arrayIndex leads to an IllegalArgumentException: Illegal slice 6 (as it is trying to select the element at the line idx 1, column 5, depth 6, which does not exist)
/*
* colSelection value is
* [5, 6, 0, 3, 9, 1]
* The array is flattened... calling a reshape(arrayData.shape()[0],arrayIndex.shape()[1]) yields
* [[5, 6, 0],
*  [3, 9, 1]]
* which is wrong.
*/
colSelection = colSelection.reshape(arrayIndex.shape()[1],arrayData.shape()[0]).transpose();
/* yields the right result
* [[5, 0, 9],
*  [6, 3, 1]]
* While this seems to be the correct way to handle the memory the performance are low:
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 8225 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 8980 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 9453 ms for 100000 iterations
Plus, this is very roundabout method for such a "simple" operation
* if the repacking of the data is commented out, the timing become:
*  -   2 columns selected (arrayIndex = {1, 5}),        ==> 6987 ms for 100000 iterations
*  -   3 columns selected (arrayIndex = {1, 5, 6}),     ==> 7976 ms for 100000 iterations
*  -   4 columns selected (arrayIndex = {1, 5, 6 ,2}),  ==> 8336 ms for 100000 iterations
*/

在不知道我正在运行什么机器的情况下,这些速度似乎还不错,但是等效的 python 代码会产生:

  • 选择 2 列(arrayIndex = {1, 5}),==> 171 毫秒,100000 次迭代
  • 选择了 3 列(arrayIndex = {1, 5, 6}),==> 173 毫秒,100000 次迭代
  • 选择了 4 列(arrayIndex = {1, 5, 6 ,2}),==> 173 毫秒,100000 次迭代

这些 java 实现最多比 python-numpy 慢 20 倍。

4

1 回答 1

2
org.nd4j.linalg.api.ndarray.INDArray arr = org.nd4j.linalg.factory.Nd4j.create(new double[][]{
                {1, 5, 0, 6, 2, 0, 9, 0, 5, 2},
                {3, 6, 1, 0, 4, 3, 1, 4, 8, 1}
        });

        org.nd4j.linalg.indexing.INDArrayIndex indices[] = {
                org.nd4j.linalg.indexing.NDArrayIndex.all(),
                new org.nd4j.linalg.indexing.SpecifiedIndex(1,5,6)
        };

        org.nd4j.linalg.api.ndarray.INDArray selected = arr.get(indices);
        System.out.println(selected);
    }

这应该适合你。这将打印:SLF4J:无法加载类“org.slf4j.impl.StaticLoggerBinder”。SLF4J:默认为无操作 (NOP) 记录器实现 SLF4J:有关详细信息,请参阅http://www.slf4j.org/codes.html#StaticLoggerBinder

[[    5.0000,         0,    9.0000], 
 [    6.0000,    3.0000,    1.0000]]

进程以退出代码 0 结束

于 2019-11-04T04:07:28.427 回答