我想我遇到了一个 CUDA 错误。有人可以确认/评论代码(见下文)。
代码(附件)将根据“BUG”定义产生不同的结果。BUG=0 结果为 8(正确),而 BUG=1 结果为 4(错误)。代码的区别只在这里:
#if BUG
unsigned int na=threadIdx.x, nb=threadIdx.y, nc=threadIdx.z;
#else
unsigned int na=0, nb=0, nc=0;
#endif
我只提交一个线程,所以 na==nb==nc==0 在这两种情况下,我也用语句检查这个:
assert( na==0 && nb==0 && nc==0 );
printf("INITIAL VALUES: %u %u %u\n",na,nb,nc);
这是我的编译和运行:
nvcc -arch=sm_21 -DBUG=0 -o bug0 bug.cu
nvcc -arch=sm_21 -DBUG=1 -o bug1 bug.cu
./bug0
./bug1
nvcc:NVIDIA (R) Cuda 编译器驱动程序 版权所有 (c) 2005-2012 NVIDIA Corporation 基于 Fri_Sep_21_17:28:58_PDT_2012 Cuda 编译工具,版本 5.0,V0.2.1221
nvcc 使用 g++-4.6 运行
最后是测试代码:
/* Compilation & run
   nvcc -arch=sm_21 -DBUG=0 -o bug0 bug.cu
   nvcc -arch=sm_21 -DBUG=1 -o bug1 bug.cu
   ./bug0
   ./bug1
 */
#include <stdio.h>
#include <assert.h>
__global__
void b(unsigned int *res)
{
#if BUG
    unsigned int na=threadIdx.x, nb=threadIdx.y, nc=threadIdx.z;
#else
    unsigned int na=0, nb=0, nc=0;
#endif
    assert( na==0 && nb==0 && nc==0 );
    printf("INITIAL VALUES: %u %u %u\n",na,nb,nc);
    unsigned int &iter=*res, na_max=2, nb_max=2, nc_max=2;
    iter=0;
    while(true)
    {
        printf("a-iter=%u     %u %u %u\n",iter,na,nb,nc);
        if( na>=na_max )
        {
            na  = 0;
            nb += blockDim.y;
            printf("b-iter=%u     %u %u %u\n",iter,na,nb,nc);
            if( nb>=nb_max )
            {
                printf("c-iter=%u     %u %u %u\n",iter,na,nb,nc);
                nb  = 0;
                nc += blockDim.z;
                if( nc>=nc_max )
                    break;  // end of loop
            }
            else
                printf("c-else\n");
        }
        else
            printf("b-else\n");
        printf("result    %u %u %u\n",na,nb,nc);
        iter++;
        na += blockDim.x;
    }
}
int main(void)
{
    unsigned int res, *d_res;
    cudaMalloc(&d_res,sizeof(unsigned int));
    b<<<1,1>>>(d_res);
    cudaMemcpy(&res, d_res, sizeof(unsigned int), cudaMemcpyDeviceToHost);
    cudaFree(d_res);
    printf("There are %u combinations (correct is 8)\n",res);
    return 0;
}