下面是一个示例,演示 @Shai 提出的使用查找表的思路:
% lut(k+1) = number of 1 bits in the byte value k, for k = 0..255
lut = sum(dec2bin(0:255) == '1', 2);
% linear indices of the nonzero entries of mlf
idx = find(mlf);
% view each index as 4 bytes (uint32(idx) saturates indices above
% intmax('uint32')) and look up the bit count of every byte
nbits = lut(1 + double(typecast(uint32(idx), 'uint8')));
% every 4 consecutive byte counts belong to one index: add them column-wise
s = sum(reshape(nbits, 4, []), 1)
如果您的稀疏数组非常大、索引超出 32 位范围,则可能需要改用 uint64(此时每个索引经 typecast 后会得到 8 个字节,因此 reshape 时应将 4 改为 8)。
编辑:
这是使用 Java 的另一种解决方案:
% linear indices of the nonzero entries of mlf
idx = find(mlf);
% count the set bits of every index with Java's Integer.bitCount
s = arrayfun(@(k) java.lang.Integer.bitCount(k), idx);
编辑#2:
这是另一种以 C++ MEX 函数实现的解决方案,它依赖于 std::bitset::count:
bitset_count.cpp
#include "mex.h"
#include <bitset>
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
// validate input/output arguments
if (nrhs != 1) {
mexErrMsgTxt("One input argument required.");
}
if (!mxIsUint32(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0])) {
mexErrMsgTxt("Input must be a 32-bit integer dense matrix.");
}
if (nlhs > 1) {
mexErrMsgTxt("Too many output arguments.");
}
// create output array
mwSize N = mxGetNumberOfElements(prhs[0]);
plhs[0] = mxCreateDoubleMatrix(N, 1, mxREAL);
// get pointers to data
double *counts = mxGetPr(plhs[0]);
uint32_T *idx = reinterpret_cast<uint32_T*>(mxGetData(prhs[0]));
// count bits set for each 32-bit integer number
for(mwSize i=0; i<N; i++) {
std::bitset<32> bs(idx[i]);
counts[i] = bs.count();
}
}
使用 mex -largeArrayDims bitset_count.cpp 编译上述函数,然后照常调用:
% linear indices of the nonzero entries; the MEX function only accepts
% uint32 input, so the indices are converted before the call
idx = find(mlf);
s = bitset_count(cast(idx, 'uint32'))
我决定比较到目前为止提到的所有解决方案:
function [t,v] = testBitsetCount()
% TESTBITSETCOUNT  Time every bit-count implementation and check they agree.
%   [T, V] = TESTBITSETCOUNT() returns T, a 5x1 vector of timeit run times
%   (seconds) in the order the methods are listed in F below, and V, a 5x1
%   cell array holding each method's output. Errors if the outputs differ.
%
%   Requires the compiled MEX function bitset_count on the MATLAB path.

% random uint32 test data; randi(imax) samples only 1..imax, so 0 would never
% be tested -- prepend both extremes (0 -> 0 bits, intmax -> 32 bits)
x = [uint32(0); intmax('uint32'); randi(intmax('uint32'), [1e5,1], 'uint32')];
% build lookup table (done once, outside the timed code)
LUT = sum(dec2bin(0:255,8)-'0', 2);
% functions to compare
f = {
    @() bit_twiddling(x);       % bit twiddling method
    @() lookup_table(x,LUT);    % lookup table method
    @() bitset_count(x);        % MEX-function (std::bitset::count)
    @() dec_to_bin(x);          % dec2bin
    @() java_bitcount(x);       % Java Integer.bitCount
};
% compare timings and check results are valid
t = cellfun(@timeit, f, 'UniformOutput',true);
v = cellfun(@feval, f, 'UniformOutput',false);
% isequal compares numeric values regardless of class, so the uint32 result
% of bit_twiddling can be compared with the double results of the others
assert(isequal(v{:}), 'testBitsetCount:mismatch', ...
    'Bit-count methods returned different results.');
end
function s = lookup_table(x,LUT)
% LOOKUP_TABLE  Bit count of 32-bit integers via a per-byte lookup table.
%   S = LOOKUP_TABLE(X, LUT) returns a column vector S holding the number of
%   set bits in each element of X (elements taken in column-major order).
%   X must be uint32 or int32 (any shape); other classes would not split
%   into exactly 4 bytes per element and are rejected.
%   LUT is a 256-element table where LUT(k+1) is the bit count of byte k.
if ~(isa(x,'uint32') || isa(x,'int32'))
    error('lookup_table:badClass', 'X must be uint32 or int32.');
end
% Split every element into its 4 bytes (their order is irrelevant, only the
% sum is used) and look up the bit count of each byte.
bytes = typecast(x(:), 'uint8');
nbits = LUT(double(bytes) + 1);
% Each run of 4 consecutive bytes belongs to one element: sum per column.
s = sum(reshape(nbits, 4, []), 1)';
end
function s = dec_to_bin(x)
% DEC_TO_BIN  Bit count of each element via its 32-digit binary string.
%   Returns a column vector: one row of '0'/'1' characters per element of X,
%   and the number of '1' characters in each row is that element's bit count.
bits = dec2bin(x, 32);
s = sum(bits == '1', 2);
end
function s = java_bitcount(x)
% JAVA_BITCOUNT  Bit count of each element using java.lang.Integer.bitCount.
%   Applies the Java method element-wise; S has the same size as X.
s = arrayfun(@(v) java.lang.Integer.bitCount(v), x);
end
function s = bit_twiddling(x)
% BIT_TWIDDLING  Parallel (SWAR) bit count of uint32 values.
%   Each pass adds neighbouring bit fields of width 1, 2, 4, 8 and 16, so
%   after the last pass every element holds its total number of set bits.
%   The result has the same size and class (uint32) as X.
masks  = uint32([1431655765 ...  % 0x55555555
                 858993459 ...   % 0x33333333
                 252645135 ...   % 0x0F0F0F0F
                 16711935 ...    % 0x00FF00FF
                 65535]);        % 0x0000FFFF
shifts = [1 2 4 8 16];
s = x;
for k = 1:numel(shifts)
    m = masks(k);
    s = bitand(s, m) + bitand(bitshift(s, -shifts(k)), m);
end
end
以秒为单位的时间:
t =
0.0009 % bit twiddling method
0.0087 % lookup table method
0.0134 % C++ std::bitset::count
0.1946 % MATLAB dec2bin
0.2343 % Java Integer.bitCount