我有一个代码,我从Numpy repeat 获得的二维数组
下面的一个适用于 numpy 数组,但会抛出 ValueError:具有多个元素的数组的真值是不明确的。与 cupy 数组一起使用时使用 a.any() 或 a.all()。对于线路ret_val[mask] = cp.repeat(arr.ravel(), rep.ravel()
我尝试使用cupy中已经存在的逻辑操作,但它们仍然会抛出错误。
def repeat2dvect(arr, rep):
lens = cp.array(rep.sum(axis=-1))
maxlen = lens.max()
ret_val = cp.zeros((arr.shape[0], int(maxlen)))
mask = (lens[:,None]>cp.arange(maxlen))
ret_val[mask] = cp.repeat(arr.ravel(), rep.ravel())
return ret_val