9

我有一个原子计数器 ( std::atomic<uint32_t> count),它将顺序递增的值发送​​给多个线程。

uint32_t my_val = ++count;

在我得到之前,my_val我想确保增量不会溢出(即:回到 0)

if (count == std::numeric_limits<uint32_t>::max())
    throw std::runtime_error("count overflow");

我认为这是一个幼稚的检查,因为如果检查是由两个线程在任一递增计数器之前执行的,则要递增的第二个线程将返回 0

if (count == std::numeric_limits<uint32_t>::max()) // if 2 threads execute this
    throw std::runtime_error("count overflow");
uint32_t my_val = ++count;       // before either gets here - possible overflow

因此,我想我需要使用一个CAS操作来确保当我增加我的计数器时,我确实在防止可能的溢出。

所以我的问题是:

  • 我的实现是否正确?
  • 它是否尽可能高效(特别是我需要检查max两次)?

我的代码(带有工作示例)如下:

#include <iostream>
#include <atomic>
#include <limits>
#include <stdexcept>
#include <thread>

std::atomic<uint16_t> count;

uint16_t get_val() // called by multiple threads
{
    uint16_t my_val;
    do
    {
        my_val = count;

        // make sure I get the next value

        if (count.compare_exchange_strong(my_val, my_val + 1))
        {
            // if I got the next value, make sure we don't overflow

            if (my_val == std::numeric_limits<uint16_t>::max())
            {
                count = std::numeric_limits<uint16_t>::max() - 1;
                throw std::runtime_error("count overflow");
            }
            break;
        }

        // if I didn't then check if there are still numbers available

        if (my_val == std::numeric_limits<uint16_t>::max())
        {
            count = std::numeric_limits<uint16_t>::max() - 1;
            throw std::runtime_error("count overflow");
        }

        // there are still numbers available, so try again
    }
    while (1);
    return my_val + 1;
}

void run()
try
{
    while (1)
    {
        if (get_val() == 0)
            exit(1);
    }

}
catch(const std::runtime_error& e)
{
    // overflow
}

int main()
{
    while (1)
    {
        count = 1;
        std::thread a(run);
        std::thread b(run);
        std::thread c(run);
        std::thread d(run);
        a.join();
        b.join();
        c.join();
        d.join();
        std::cout << ".";
    }
    return 0;
}
4

3 回答 3

8

是的,您需要使用CAS操作。

std::atomic<uint16_t> g_count;

uint16_t get_next() {
   uint16_t new_val = 0;
   do {
      uint16_t cur_val = g_count;                                            // 1
      if (cur_val == std::numeric_limits<uint16_t>::max()) {                 // 2
          throw std::runtime_error("count overflow");
      }
      new_val = cur_val + 1;                                                 // 3
   } while(!std::atomic_compare_exchange_weak(&g_count, &cur_val, new_val)); // 4

   return new_val;
}

这个想法如下:一旦g_count == std::numeric_limits<uint16_t>::max()get_next()函数总是会抛出异常。

脚步:

  1. 获取计数器的当前值
  2. 如果它是最大的,抛出一个异常(不再可用的数字)
  3. 获取新值作为当前值的增量
  4. 尝试以原子方式设置新值。如果我们未能设置它(它已经由另一个线程完成),请再试一次。
于 2013-05-28T06:37:30.540 回答
2

如果效率是一个大问题,那么我建议不要对检查那么严格。我猜在正常使用下溢出不会成为问题,但你真的需要完整的 65K 范围(你的示例使用 uint16)吗?

如果您假设您正在运行的线程数有一个最大值,那会更容易。这是一个合理的限制,因为没有程序具有无限数量的并发。因此,如果您有N线程,您可以简单地将溢出限制减少到65K - N. 要比较是否溢出,您不需要 CAS:

uint16_t current = count.load(std::memory_order_relaxed);
if( current >= (std::numeric_limits<uint16_t>::max() - num_threads - 1) )
    throw std::runtime_error("count overflow");
count.fetch_add(1,std::memory_order_relaxed);

这会产生软溢出条件。如果两个线程同时来到这里,它们都可能会通过,但这没关系,因为 count 变量本身永远不会溢出。此时任何未来的到达都会在逻辑上溢出(直到再次减少计数)。

于 2013-05-28T06:36:37.570 回答
1

在我看来,仍然存在一个竞争条件,count它将暂时设置为 0,以便另一个线程将看到 0 值。

假设count是 atstd::numeric_limits<uint16_t>::max()并且两个线程尝试获取增加的值。在 Thread 1 执行的那一刻count.compare_exchange_strong(my_val, my_val + 1),count 设置为 0,这就是 Thread 2 将看到的,如果它get_val()在 Thread 1 有机会恢复count到之前调用并完成max()

于 2013-05-28T06:35:12.143 回答