这似乎可行,其他人可能有更好的想法:
#include <ostream>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/transform.h>
#include <thrust/functional.h>
#include <thrust/copy.h>
#include <thrust/fill.h>
#define DSIZE 10
template <typename T>
thrust::device_vector<T> operator+(thrust::device_vector<T> &lhs, const thrust::device_vector<T> &rhs) {
thrust::transform(rhs.begin(), rhs.end(),
lhs.begin(), lhs.begin(), thrust::plus<T>());
return lhs;
}
template <typename T>
thrust::host_vector<T> operator+(thrust::host_vector<T> &lhs, const thrust::host_vector<T> &rhs) {
thrust::transform(rhs.begin(), rhs.end(),
lhs.begin(), lhs.begin(), thrust::plus<T>());
return lhs;
}
int main() {
thrust::device_vector<float> dvec(DSIZE);
thrust::device_vector<float> otherdvec(DSIZE);
thrust::fill(dvec.begin(), dvec.end(), 1.0f);
thrust::fill(otherdvec.begin(), otherdvec.end(), 2.0f);
thrust::host_vector<float> hresult1 = dvec + otherdvec;
std::cout << "result 1: ";
thrust::copy(hresult1.begin(), hresult1.end(), std::ostream_iterator<float>(std::cout, " ")); std::cout << std::endl;
thrust::host_vector<float> hvec(DSIZE);
thrust::fill(hvec.begin(), hvec.end(), 5.0f);
thrust::host_vector<float> hresult2 = hvec + hresult1;
std::cout << "result 2: ";
thrust::copy(hresult2.begin(), hresult2.end(), std::ostream_iterator<float>(std::cout, " ")); std::cout << std::endl;
// this line would produce a compile error:
// thrust::host_vector<float> hresult3 = dvec + hvec;
return 0;
}
请注意,无论哪种情况,我都可以为结果指定主机或设备向量,因为推力会看到差异并自动生成必要的复制操作。因此,我的模板中的结果向量类型(主机、设备)并不重要。
另请注意,thrust::transform
您在模板定义中的函数参数并不完全正确。