3

我有两个整数数组,dmap 并且dflag 在相同长度的设备上,我用推力设备指针包装了它们,dmapt并且 dflagt

dmap 数组中有一些元素的值为 -1。我想从 dflag 数组中删除这些 -1 和相应的值。

我正在使用 remove_if 函数来执行此操作,但我无法弄清楚这个调用的返回值是什么,或者我应该如何使用这个返回值来获取 .

(我想将这些简化的数组传递给reduce_by_keydflagt 将用作键的函数。)

我正在使用以下调用进行减少。请让我知道如何将返回的值存储在变量中并使用它来处理各个数组dflagdmap

thrust::remove_if( 
    thrust::make_zip_iterator(thrust::make_tuple(dmapt, dflagt)), 
    thrust::make_zip_iterator(thrust::make_tuple(dmapt+numindices, dflagt+numindices)), 
    minus_one_equality_test() 
); 

上面使用的谓词函子定义为

struct minus_one_equality_test
{ 
    typedef typename thrust::tuple<int,int> Tuple; 
    __host__ __device__ 
    bool operator()(const Tuple& a ) 
    { 
        return  thrust::get<0>(a) ==  (-1); 
    } 
} 
4

1 回答 1

6

返回值是一个 zip_iterator,它标记了在 remove_if 调用期间仿函数返回 true 的元组序列的新结尾。要访问底层数组的新结束迭代器,您需要从 zip_iterator 检索元组迭代器;然后,该元组的内容是您用于构建 zip_iterator 的原始数组的新结束迭代器。文字比代码复杂得多:

#include <thrust/tuple.h>
#include <thrust/device_vector.h>
#include <thrust/device_ptr.h>
#include <thrust/remove.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/copy.h>

#include <iostream>

struct minus_one_equality_test
{ 
    typedef thrust::tuple<int,int> Tuple; 
    __host__ __device__ 
    bool operator()(const Tuple& a ) 
    { 
        return  thrust::get<0>(a) ==  (-1); 
    }; 
}; 


int main(void)
{
    const int numindices = 10;

    int mapt[numindices] = { 1, 2, -1, 4, 5, -1, 7, 8, -1, 10 };
    int flagt[numindices] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };

    thrust::device_vector<int> vmapt(10);
    thrust::device_vector<int> vflagt(10);

    thrust::copy(mapt, mapt+numindices, vmapt.begin());
    thrust::copy(flagt, flagt+numindices, vflagt.begin());

    thrust::device_ptr<int> dmapt = vmapt.data();
    thrust::device_ptr<int> dflagt = vflagt.data();

    typedef thrust::device_vector< int >::iterator  VIt;
    typedef thrust::tuple< VIt, VIt > TupleIt;
    typedef thrust::zip_iterator< TupleIt >  ZipIt;

    ZipIt Zend = thrust::remove_if(  
        thrust::make_zip_iterator(thrust::make_tuple(dmapt, dflagt)), 
        thrust::make_zip_iterator(thrust::make_tuple(dmapt+numindices, dflagt+numindices)), 
        minus_one_equality_test() 
    ); 

    TupleIt Tend = Zend.get_iterator_tuple();
    VIt vmapt_end = thrust::get<0>(Tend);

    for(VIt x = vmapt.begin(); x != vmapt_end; x++) {
        std::cout << *x << std::endl;
    }

    return 0;
}

如果您编译并运行它,您应该会看到如下内容:

$ nvcc -arch=sm_12 remove_if.cu 
$ ./a.out
1
2
4
5
7
8
10

在此示例中,我仅“检索”元组第一个元素的短内容,第二个元素以相同的方式访问,即。标记向量新端的迭代器是thrust::get<1>(Tend).

于 2012-09-04T20:28:04.757 回答