__global__
static void find_groups(int *locs, int *sorted, int num)
{
int bid = blockIdx.y * gridDim.x + blockIdx.x;
int tid = bid * blockDim.x + threadIdx.x;
if (tid < num) {
int curr = sorted[tid];
if (tid == 0 || curr != sorted[tid - 1]) locs[curr] = tid;
}
}
int main()
{
int h_P0[N] = {0, 0, 1, 2, 1, 1, 0, 2, 0, 0};
int h_P1[N] = {0, 1, 1, 2, 1, 2, 0, 2, 1, 0};
thrust::host_vector<int> th_P0(h_P0, h_P0 + N);
thrust::host_vector<int> th_P1(h_P1, h_P1 + N);
thrust::device_vector<int> td_P0 = th_P0;
thrust::device_vector<int> td_P1 = th_P1;
thrust::device_vector<int> td_S0(N);
thrust::device_vector<int> td_S1(N);
thrust::sequence(td_S0.begin(), td_S0.end());
thrust::sequence(td_S1.begin(), td_S1.end());
thrust::stable_sort_by_key(td_P0.begin(), td_P0.end(), td_S0.begin());
thrust::stable_sort_by_key(td_P1.begin(), td_P1.end(), td_S1.begin());
thrust::device_vector<int> td_l0(3, -1); // Changed here
thrust::device_vector<int> td_l1(3, -1); // And here
int threads = 256;
int blocks_x = (N + 256) / 256;
int blocks_y = (blocks_x + 65535) / 65535;
dim3 blocks(blocks_x, blocks_y);
int *d_l0 = thrust::raw_pointer_cast(td_l0.data());
int *d_l1 = thrust::raw_pointer_cast(td_l1.data());
int *d_P0 = thrust::raw_pointer_cast(td_P0.data());
int *d_P1 = thrust::raw_pointer_cast(td_P1.data());
find_groups<<<blocks, threads>>>(d_l0, d_P0, N);
find_groups<<<blocks, threads>>>(d_l1, d_P1, N);
return 0;
}
该算法可以用简单的步骤来解释。
- 按键排序 P0
- 按键排序 P1
- 键现在包含第二个表
现在将 P0 和 P1 传递给 find_groups 内核。由于您知道只有 3 个组,因此组号从 n-1 变为 n 的线程写入全局内存。线程 0 将始终写入 0,因为这是所有向量的第一组的开始。
我试着把它们打印出来。这就是我得到的。请记住,一切都是 0 索引。
Sorted
t t+1
0 0
1 6
6 9
8 1
9 2
2 4
4 8
5 3
3 5
7 7
Ranges
Groups t t + 1
S [0-4] [0-2]
I [5-7] [3-6]
R [8-9] [7-9]
如果您需要访问完整代码(包括打印代码),请访问此链接。
我不确定这是否足够。但是,如果我在这里遗漏了什么,请告诉我。
编辑
更改了代码以处理缺少类的位置。用 -1 初始化相关向量。因此,当您遇到 -1 的起点时,这意味着该类不会出现在该迭代中。