此示例说明如何使用 reduce_by_key 算法计算每一行的总和。您可以轻松地调整该示例来计算每行的最小值或最大值。要同时计算每行的最小值和最大值,您需要使用此策略。具体来说,您需要在输入数据上使用 transform_iterator,将每个值 x 转换为元组 (x,x)(在下面的代码中由 minmax_pair 结构体表示),然后再应用 minmax_binary_op 归约运算符。
这是一个完整的例子:
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/generate.h>
#include <thrust/transform_reduce.h>
#include <thrust/functional.h>
#include <thrust/extrema.h>
#include <thrust/random.h>
#include <iostream>
#include <iomanip>
// minmax_pair is a plain aggregate holding the smallest and the
// largest value encountered so far.
template <typename T>
struct minmax_pair
{
    T min_val; // smallest value encountered so far
    T max_val; // largest value encountered so far
};
// minmax_unary_op lifts a single value x into a minmax_pair whose
// minimum and maximum are both x, so that individual elements can be
// fed to a reduction that operates on minmax_pair values.
template <typename T>
struct minmax_unary_op
    : public thrust::unary_function< T, minmax_pair<T> >
{
    __host__ __device__
    minmax_pair<T> operator()(const T& x) const
    {
        // both fields start out equal to the one value seen
        minmax_pair<T> seed = { x, x };
        return seed;
    }
};
// minmax_binary_op combines two minmax_pair values into one: the
// resulting minimum is the smaller of the two input minimums and the
// resulting maximum is the larger of the two input maximums.
template <typename T>
struct minmax_binary_op
    : public thrust::binary_function< minmax_pair<T>,
                                      minmax_pair<T>,
                                      minmax_pair<T> >
{
    __host__ __device__
    minmax_pair<T> operator()(const minmax_pair<T>& lhs, const minmax_pair<T>& rhs) const
    {
        // field order is {min_val, max_val}
        minmax_pair<T> merged = {
            thrust::min(lhs.min_val, rhs.min_val),
            thrust::max(lhs.max_val, rhs.max_val)
        };
        return merged;
    }
};
// linear_index_to_row_index converts a linear element index into the
// index of the row that contains it by dividing by the number of
// columns C (integer division when T is an integral type).
// NOTE: C must be non-zero; operator() divides by it unconditionally.
template <typename T>
struct linear_index_to_row_index : public thrust::unary_function<T,T>
{
    T C; // number of columns

    // num_columns: number of columns in each row
    __host__ __device__
    linear_index_to_row_index(T num_columns) : C(num_columns) {}

    // Returns the row index for linear index i.
    // const-qualified because the call does not modify the functor; this
    // also lets it be invoked through const objects, like the other
    // functors in this file.
    __host__ __device__
    T operator()(T i) const
    {
        return i / C;
    }
};
int main(void)
{
int R = 5; // number of rows
int C = 8; // number of columns
thrust::default_random_engine rng;
thrust::uniform_int_distribution<int> dist(0, 99);
// initialize data
thrust::device_vector<int> array(R * C);
for (size_t i = 0; i < array.size(); i++)
array[i] = dist(rng);
// allocate storage for per-row results and indices
thrust::device_vector< minmax_pair<int> > row_results(R);
thrust::device_vector< int > row_indices(R);
// compute row sums by summing values with equal row indices
thrust::reduce_by_key
(thrust::make_transform_iterator(thrust::counting_iterator<int>(0), linear_index_to_row_index<int>(C)),
thrust::make_transform_iterator(thrust::counting_iterator<int>(0), linear_index_to_row_index<int>(C)) + (R*C),
thrust::make_transform_iterator(array.begin(), minmax_unary_op<int>()),
row_indices.begin(),
row_results.begin(),
thrust::equal_to<int>(),
minmax_binary_op<int>());
// print data
for(int i = 0; i < R; i++)
{
minmax_pair<int> result = row_results[i];
std::cout << "[";
for(int j = 0; j < C; j++)
std::cout << std::setw(3) << array[i * C + j] << " ";
std::cout << "] = " << "(" << result.min_val << "," << result.max_val << ")\n";
}
return 0;
}
示例输出:
[ 0 8 60 89 96 18 51 39 ] = (0,96)
[ 26 74 8 56 58 80 59 51 ] = (8,80)
[ 87 99 72 96 29 42 89 65 ] = (29,99)
[ 90 96 16 85 90 29 93 41 ] = (16,96)
[ 30 51 39 78 68 54 59 9 ] = (9,78)