1

我正在尝试使用 Thrust 减少一组值的最小值和最大值,但我似乎被卡住了。给定一组浮点数,我想要一次减少它们的最小值和最大值,但是使用推力的 reduce 方法,我得到了所有模板编译错误的母亲(或至少是阿姨)。

我的原始代码包含 5 个值列表,分布在 2 个我想要减少的 float4 数组中,但我已将其归结为这个简短的示例。

struct ReduceMinMax {
    __host__ __device__
    float2 operator()(float lhs, float rhs) {
        return make_float2(Min(lhs, rhs), Max(lhs, rhs));
    }
};

int main(int argc, char *argv[]){

    thrust::device_vector<float> hat(4);
    hat[0] = 3;
    hat[1] = 5;
    hat[2] = 6;
    hat[3] = 1;

    ReduceMinMax binary_op_of_dooooom;
    thrust::reduce(hat.begin(), hat.end(), 4.0f, binary_op_of_dooooom);
}

如果我将它分成 2 个缩减而不是它当然有效。那么我的问题是:是否有可能通过推力同时减少最小值和最大值,以及如何减少?如果不是,那么实现上述减少的最有效方法是什么?转换迭代器会帮助我吗(如果是这样,那么减少会是一次减少吗?)

一些附加信息:我正在使用 Thrust 1.5(由 CUDA 4.2.7 提供) 我的实际代码使用的是 reduce_by_key,而不仅仅是 reduce。我在写这个问题时发现了 transform_reduce ,但是这个问题没有考虑到密钥。

4

1 回答 1

4

正如 talonmies 所指出的,您的归约不会编译,因为thrust::reduce期望二元运算符的参数类型与其结果类型匹配,但是ReduceMinMax的参数类型是float,而它的结果类型是float2

thrust::minmax_element直接实现此操作,但如有必要,您可以改为使用 来实现归约thrust::inner_product,这概括了thrust::reduce

#include <thrust/inner_product.h>
#include <thrust/device_vector.h>
#include <thrust/extrema.h>
#include <cassert>

struct minmax_float
{
  __host__ __device__
  float2 operator()(float lhs, float rhs)
  {
    return make_float2(thrust::min(lhs, rhs), thrust::max(lhs, rhs));
  }
};

struct minmax_float2
{
  __host__ __device__
  float2 operator()(float2 lhs, float2 rhs)
  {
    return make_float2(thrust::min(lhs.x, rhs.x), thrust::max(lhs.y, rhs.y));
  }
};

float2 minmax1(const thrust::device_vector<float> &x)
{
  return thrust::inner_product(x.begin(), x.end(), x.begin(), make_float2(4.0, 4.0f), minmax_float2(), minmax_float());
}

float2 minmax2(const thrust::device_vector<float> &x)
{
  using namespace thrust;
  pair<device_vector<float>::const_iterator, device_vector<float>::const_iterator> ptr_to_result;

  ptr_to_result = minmax_element(x.begin(), x.end());

  return make_float2(*ptr_to_result.first, *ptr_to_result.second);
}

int main()
{
  thrust::device_vector<float> hat(4);
  hat[0] = 3;
  hat[1] = 5;
  hat[2] = 6;
  hat[3] = 1;

  float2 result1 = minmax1(hat);
  float2 result2 = minmax2(hat);

  assert(result1.x == result2.x);
  assert(result1.y == result2.y);
}
于 2012-05-10T21:42:14.613 回答