I'm currently working with Metal compute shaders and trying to understand how GPU thread synchronization works there.
I wrote some simple code, but it doesn't behave the way I expect.
Suppose I have a threadgroup variable, an array to which all threads can write their output simultaneously:
kernel void compute_features(device float &output [[ buffer(0) ]],
                             ushort2 group_pos [[ threadgroup_position_in_grid ]],
                             ushort2 thread_pos [[ thread_position_in_threadgroup ]],
                             ushort tid [[ thread_index_in_threadgroup ]])
{
    threadgroup short blockIndices[288];
    float someValue = 0.0;
    // doing some work here which fills someValue...
    blockIndices[thread_pos.y * THREAD_COUNT_X + thread_pos.x] = someValue;
    // wait until all threads are done with their calculations
    threadgroup_barrier(mem_flags::mem_none);
    // every thread adds its own result to output
    output += blockIndices[thread_pos.y * THREAD_COUNT_X + thread_pos.x];
}
The code above doesn't work. The output variable doesn't contain the sum of all threads' calculations; it contains only the value from whichever thread was presumably the last to add its value to output. To me it seems like threadgroup_barrier does absolutely nothing.
Now, the interesting part. The code below works:
    blockIndices[thread_pos.y * THREAD_COUNT_X + thread_pos.x] = someValue;
    // wait until all threads are done with their calculations
    threadgroup_barrier(mem_flags::mem_none);
    if (tid == 0) {
        for (int i = 0; i < 288; i++) {
            output += blockIndices[i]; // accumulate every thread's result into output
        }
    }
And this code, without the barrier, works just as well as the previous one:
    blockIndices[thread_pos.y * THREAD_COUNT_X + thread_pos.x] = someValue;
    if (tid == 0) {
        for (int i = 0; i < 288; i++) {
            output += blockIndices[i]; // accumulate every thread's result into output
        }
    }
To summarize: my code works as expected only when I handle the threadgroup memory in a single GPU thread. It doesn't matter which thread it is; it can be the last thread in the threadgroup as well as the first. The presence of threadgroup_barrier makes absolutely no difference. I also tried threadgroup_barrier with the mem_threadgroup flag, and the code still doesn't work.
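Put together as a complete kernel, the single-thread variant described above might look like the sketch below. It is only an illustration under stated assumptions: the value of THREAD_COUNT_X is not shown in the snippets, so 16 is a hypothetical placeholder (16 x 18 threads would match the 288-element array), output is assumed to be bound as a device reference, and the barrier uses mem_flags::mem_threadgroup rather than mem_none.

```metal
#include <metal_stdlib>
using namespace metal;

// Hypothetical value for illustration; the snippets above do not show it.
// 16 x 18 = 288 threads per threadgroup would match the array size.
#define THREAD_COUNT_X 16

kernel void compute_features(device float &output [[ buffer(0) ]],
                             ushort2 thread_pos [[ thread_position_in_threadgroup ]],
                             ushort tid [[ thread_index_in_threadgroup ]])
{
    // Declared as in the snippets above; float values are converted to short on store.
    threadgroup short blockIndices[288];

    float someValue = 0.0;
    // ... work that fills someValue ...

    // Each thread writes its own slot in threadgroup memory.
    blockIndices[thread_pos.y * THREAD_COUNT_X + thread_pos.x] = someValue;

    // Every thread in the threadgroup reaches this point before any continues;
    // mem_threadgroup also orders the threadgroup-memory writes above.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // A single thread reads every slot and accumulates into output.
    if (tid == 0) {
        for (int i = 0; i < 288; i++) {
            output += blockIndices[i];
        }
    }
}
```

In this form only one thread per threadgroup ever touches output, which is the one difference from the first version, where every thread performs its own read-modify-write on it.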
I understand that I might be missing some very important detail, and I would be grateful if someone could point out my errors. Thanks in advance!