5

您好,我在 C+ 中有这个循环,我试图将其转换为推力,但没有得到相同的结果......有什么想法吗?谢谢你

C++ 代码

for (i=0;i<n;i++) 
    for (j=0;j<n;j++) 
      values[i]=values[i]+(binv[i*n+j]*d[j]);

推力代码

thrust::fill(values.begin(), values.end(), 0);
thrust::transform(make_zip_iterator(make_tuple(
                thrust::make_permutation_iterator(values.begin(), thrust::make_transform_iterator(thrust::make_counting_iterator(0), IndexDivFunctor(n))),
                binv.begin(),
                thrust::make_permutation_iterator(d.begin(), thrust::make_transform_iterator(thrust::make_counting_iterator(0), IndexModFunctor(n))))),
                make_zip_iterator(make_tuple(
                thrust::make_permutation_iterator(values.begin(), thrust::make_transform_iterator(thrust::make_counting_iterator(0), IndexDivFunctor(n))) + n,
                binv.end(),
                thrust::make_permutation_iterator(d.begin(), thrust::make_transform_iterator(thrust::make_counting_iterator(0), IndexModFunctor(n))) + n)),
                thrust::make_permutation_iterator(values.begin(), thrust::make_transform_iterator(thrust::make_counting_iterator(0), IndexDivFunctor(n))),
                function1()
                );

推力函数

struct IndexDivFunctor: thrust::unary_function<int, int>
{
  int n;

  IndexDivFunctor(int n_) : n(n_) {}

  __host__ __device__
  int operator()(int idx)
  {
    return idx / n;
  }
};

struct IndexModFunctor: thrust::unary_function<int, int>
{
  int n;

  IndexModFunctor(int n_) : n(n_) {}

  __host__ __device__
  int operator()(int idx)
  {
    return idx % n;
  }
};


struct function1
{
  template <typename Tuple>
  __host__ __device__
  double operator()(Tuple v)
  {
    return thrust::get<0>(v) + thrust::get<1>(v) * thrust::get<2>(v);
  }
};
4

2 回答 2

4

首先,一些一般性评论。你的循环

for (i=0;i<n;i++) 
    for (j=0;j<n;j++) 
      v[i]=v[i]+(B[i*n+j]*d[j]);

相当于标准的BLAS gemv操作

在此处输入图像描述

其中矩阵以行主要顺序存储。在设备上执行此操作的最佳方法是使用 CUBLAS,而不是由推力基元构造的东西。

话虽如此,您发布的推力代码绝对不会像您的序列代码那样做。您看到的错误不是浮点关联性的结果。从根本thrust::transform上将提供的函子应用于输入迭代器的每个元素,并将结果存储在输出迭代器上。要产生与您发布的循环相同的结果,该thrust::transform调用需要对您发布的 fmad 仿函数执行 (n*n) 次操作。显然不是。此外,不能保证thrust::transform会以不会出现内存竞争的方式执行求和/减少操作。

正确的解决方案可能是这样的:

  1. 使用推力::transform 计算Bd元素的 (n*n) 乘积
  2. 使用推力::reduce_by_key 将产品减少为部分和,产生Bd
  3. 使用推力::transform 将得到的矩阵向量乘积添加到v以产生最终结果。

在代码中,首先定义一个这样的仿函数:

struct functor
{
  template <typename Tuple>
  __host__ __device__
  double operator()(Tuple v)
  {
    return thrust::get<0>(v) * thrust::get<1>(v);
  }
};

然后执行以下操作来计算矩阵向量乘法

  typedef thrust::device_vector<int> iVec;
  typedef thrust::device_vector<double> dVec;

  typedef thrust::counting_iterator<int> countIt;
  typedef thrust::transform_iterator<IndexDivFunctor, countIt> columnIt;
  typedef thrust::transform_iterator<IndexModFunctor, countIt> rowIt;

  // Assuming the following allocations on the device
  dVec B(n*n), v(n), d(n);

  // transformation iterators mapping to vector rows and columns
  columnIt cv_begin = thrust::make_transform_iterator(thrust::make_counting_iterator(0), IndexDivFunctor(n));
  columnIt cv_end   = cv_begin + (n*n);

  rowIt rv_begin = thrust::make_transform_iterator(thrust::make_counting_iterator(0), IndexModFunctor(n));
  rowIt rv_end   = rv_begin + (n*n);

  dVec temp(n*n);
  thrust::transform(make_zip_iterator(
                      make_tuple(
                        B.begin(),
                        thrust::make_permutation_iterator(d.begin(),rv_begin) ) ),
                    make_zip_iterator(
                      make_tuple(
                        B.end(),
                        thrust::make_permutation_iterator(d.end(),rv_end) ) ),
                    temp.begin(),
                    functor());

  iVec outkey(n);
  dVec Bd(n);
  thrust::reduce_by_key(cv_begin, cv_end, temp.begin(), outkey.begin(), Bd.begin());
  thrust::transform(v.begin(), v.end(), Bd.begin(), v.begin(), thrust::plus<double>());

当然,与使用专门设计的矩阵向量乘法代码(如dgemvCUBLAS)相比,这是一种非常低效的计算方式。

于 2011-10-09T12:06:23.457 回答
0

你的结果有多大不同?这是一个完全不同的答案,还是仅在最后一位数字上有所不同?循环只执行一次,还是某种迭代过程?

由于精度问题,浮点运算,尤其是那些重复相加或乘以某些值的运算,不是关联的。此外,如果您使用快速数学优化,这些操作可能不是 IEEE 兼容的。

对于初学者,请查看有关浮点数的维基百科部分:http ://en.wikipedia.org/wiki/Floating_point#Accuracy_problems

于 2011-10-05T10:08:01.767 回答