我有一个原子计数器 ( std::atomic<uint32_t> count
),它将顺序递增的值发送给多个线程。
uint32_t my_val = ++count;
在我得到之前,my_val
我想确保增量不会溢出(即:回到 0)
if (count == std::numeric_limits<uint32_t>::max())
throw std::runtime_error("count overflow");
我认为这是一个幼稚的检查,因为如果检查是由两个线程在任一递增计数器之前执行的,则要递增的第二个线程将返回 0
if (count == std::numeric_limits<uint32_t>::max()) // if 2 threads execute this
throw std::runtime_error("count overflow");
uint32_t my_val = ++count; // before either gets here - possible overflow
因此,我想我需要使用一个CAS
操作来确保当我增加我的计数器时,我确实在防止可能的溢出。
所以我的问题是:
- 我的实现是否正确?
- 它是否尽可能高效(特别是我需要检查
max
两次)?
我的代码(带有工作示例)如下:
#include <iostream>
#include <atomic>
#include <limits>
#include <stdexcept>
#include <thread>
std::atomic<uint16_t> count;
uint16_t get_val() // called by multiple threads
{
uint16_t my_val;
do
{
my_val = count;
// make sure I get the next value
if (count.compare_exchange_strong(my_val, my_val + 1))
{
// if I got the next value, make sure we don't overflow
if (my_val == std::numeric_limits<uint16_t>::max())
{
count = std::numeric_limits<uint16_t>::max() - 1;
throw std::runtime_error("count overflow");
}
break;
}
// if I didn't then check if there are still numbers available
if (my_val == std::numeric_limits<uint16_t>::max())
{
count = std::numeric_limits<uint16_t>::max() - 1;
throw std::runtime_error("count overflow");
}
// there are still numbers available, so try again
}
while (1);
return my_val + 1;
}
void run()
try
{
while (1)
{
if (get_val() == 0)
exit(1);
}
}
catch(const std::runtime_error& e)
{
// overflow
}
int main()
{
while (1)
{
count = 1;
std::thread a(run);
std::thread b(run);
std::thread c(run);
std::thread d(run);
a.join();
b.join();
c.join();
d.join();
std::cout << ".";
}
return 0;
}