0

如何在恒定空间中向多个消费者提供迭代?

TLDR

编写一个在 CONSTANT SPACE 中通过以下测试的实现,同时将和min视为黑盒子。maxsum

def testit(implementation, N):
    assert implementation(range(N), min, max, sum) == (0, N-1, N*(N-1)//2)

讨论

我们喜欢迭代器,因为它们让我们可以懒惰地处理数据流,允许在 CONSTANT SPACE 中处理大量数据。

def source_summary(source, summary):
    return summary(source)

N = 10 ** 8
print(source_summary(range(N), min))
print(source_summary(range(N), max))
print(source_summary(range(N), sum))

每行需要几秒钟来执行,但使用的内存很少。但是,它确实需要对源进行 3 次单独的遍历。因此,如果您的来源是网络连接、数据采集硬件等,这将不起作用,除非您将所有数据缓存在某处,从而失去 CONSTANT SPACE 要求。

这是一个演示此问题的版本

def source_summaries(source, *summaries):
    from itertools import tee
    return tuple(map(source_summary, tee(source, len(summaries)),
                                     summaries))

testit(source_summaries, N)
print('OK')

测试通过,但tee必须保留所有数据的副本,因此空间使用量从 上升O(1)O(N).

如何在具有恒定内存的单次遍历中获得结果?

当然,有可能通过顶部给出的测试,使用O(1)空间,通过作弊:使用测试使用的特定迭代器消费者的知识。但这不是重点:source_summaries应该与任何迭代器消耗品一起使用,例如set, collections.Counter, ''.join,包括将来可能编写的任何和所有内容。实现必须将它们视为黑匣子。

需要明确的是:关于消费者的唯一可用知识是每个消费者消费一个迭代并返回一个结果。使用有关消费者的任何其他知识都是作弊。

想法

[编辑:我已经发布了这个想法的实现作为答案]

我可以想象一个使用的解决方案(我真的不喜欢)

  • 抢占式线程

  • 将消费者链接到源的自定义迭代器

让我们调用自定义迭代器link

  • 对于每个消费者,运行
result = consumer(<link instance for this thread>)
<link instance for this thread>.set_result(result)

在单独的线程上。

  • 在主线程上,类似于
for item in source:
    for l in links:
        l.push(item)

for l in links:
    l.stop()

for thread in threads:
    thread.join()

return tuple(link.get_result, links)
  • link.__next__阻塞直到link实例收到

    • .push(item)在这种情况下,它会返回项目
    • .stop()在这种情况下,它会引发StopIteration
  • 数据竞赛看起来像是一场噩梦。您需要一个推送队列,并且可能需要将一个哨兵对象放置在队列中link.stop()...以及我忽略的其他一些事情。

我更喜欢使用合作线程,但consumer(link)似乎不可避免地不合作。

你有什么不那么混乱的建议吗?

4

2 回答 2

1

这是您想法的另一种实现方式。它使用协作多线程。正如您所建议的,关键是使用多线程并让迭代器__next__方法阻塞,直到所有线程都消耗了当前迭代。

此外,迭代器包含一个(可选的)恒定大小的缓冲区。有了这个缓冲区,我们可以分块读取源代码并避免大量锁定/同步。

我的实现还处理了一些消费者在到达迭代器末尾之前停止迭代的情况。

import threading

class BufferedMultiIter:
    def __init__(self, source, n, bufsize = 1):
        '''`source` is an iterator or iterable,
        `n` is the number of threads that will interact with this iterator,
        `bufsize` is the size of the internal buffer. The iterator will read
        and buffer elements from `source` in chunks of `bufsize`. The bigger
        the buffer is, the better the performance but also the bigger the
        (constant) space requirement.
        '''
        self._source = iter(source)
        self._n = n
        # Condition variable for synchronization
        self._cond = threading.Condition()
        # Buffered values
        bufsize = max(bufsize, 1)
        self._buffer = [None] * bufsize
        self._buffered = 0
        self._next = threading.local()
        # State variables to implement the "wait for buffer to get refilled"
        # protocol
        self._serial = 0
        self._waiting = 0

        # True if we reached the end of the source
        self._stop = False
        # Was the thread killed (for error handling)?
        self._killed = False

    def _fill_buffer(self):
        '''Refill the internal buffer.'''
        self._buffered = 0
        while self._buffered < len(self._buffer):
            try:
                self._buffer[self._buffered] = next(self._source)
                self._buffered += 1
            except StopIteration:
                self._stop = True
                break
            # Explicitly clear the unused part of the buffer to release
            # references as early as possible
            for i in range(self._buffered, len(self._buffer)):
                self._buffer[i] = None
        self._waiting = 0
        self._serial += 1

    def register_thread(self):
        '''Register a thread.

        Each thread that wants to access this iterator must first register
        with the iterator. It is an error to register the same thread more
        than once. It is an error to access this iterator with a thread that
        was not registered (with the exception of calling `kill`). It is an
        error to register more threads than the number that was passed to the
        constructor.
        '''
        self._next.i = 0

    def unregister_thread(self):
        '''Unregister a thread from this iterator.

        This should be called when a thread is done using the iterator.
        It catches the case in which a consumer does not consume all the
        elements from the iterator but exits early.
        '''
        assert hasattr(self._next, 'i')
        delattr(self._next, 'i')
        with self._cond:
            assert self._n > 0
            self._n -= 1
            if self._waiting == self._n:
                self._fill_buffer()
            self._cond.notify_all()

    def kill(self):
        '''Forcibly kill this iterator.

        This will wake up all threads currently blocked in `__next__` and
        will have them raise a `StopIteration`.
        This function should be called in case of error to terminate all
        threads as fast as possible.
        '''
        self._cond.acquire()
        self._killed = True
        self._stop = True
        self._cond.notify_all()
        self._cond.release()
    def __iter__(self): return self
    def __next__(self):
        if self._next.i == self._buffered:
            # We read everything from the buffer.
            # Wait until all other threads have also consumed the buffer
            # completely and then refill it.
            with self._cond:
                old = self._serial
                self._waiting += 1
                if self._waiting == self._n:
                    self._fill_buffer()
                    self._cond.notify_all()
                else:
                    # Wait until the serial number changes. A change in
                    # serial number indicates that another thread has filled
                    # the buffer
                    while self._serial == old and not self._killed:
                        self._cond.wait()
            # Start at beginning of newly filled buffer
            self._next.i = 0

        if self._killed:
            raise StopIteration
        k = self._next.i
        if k == self._buffered and self._stop:
            raise StopIteration
        value = self._buffer[k]
        self._next.i = k + 1
        return value

class NotAll:
    '''A consumer that does not consume all the elements from the source.'''
    def __init__(self, limit):
        self._limit = limit
        self._consumed = 0
    def __call__(self, it):
        last = None
        for k in it:
            last = k
            self._consumed += 1
            if self._consumed >= self._limit:
                break
        return last

def multi_iter(iterable, *consumers, **kwargs):
    '''Iterate using multiple consumers.

    Each value in `iterable` is presented to each of the `consumers`.
    The function returns a tuple with the results of all `consumers`.

    There is an optional `bufsize` argument. This controls the internal
    buffer size. The bigger the buffer, the better the performance, but also
    the bigger the (constant) space requirement of the operation.

    NOTE: This will spawn a new thread for each consumer! The iteration is
    multi-threaded and happens in parallel for each element.
    '''
    n = len(consumers)
    it = BufferedMultiIter(iterable, n, kwargs.get('bufsize', 1))
    threads = list() # List with **running** threads
    result = [None] * n
    def thread_func(i, c):
        it.register_thread()
        result[i] = c(it)
        it.unregister_thread()
    try:
        for c in consumers:
            t = threading.Thread(target = thread_func, args = (len(threads), c))
            t.start()
            threads.append(t)
    except:
        # Here we should forcibly kill all the threads but there is not
        # t.kill() function or similar. So the best we can do is stop the
        # iterator
        it.kill()
    finally:
        while len(threads) > 0:
            t = threads.pop(-1)
            t.join()
    return tuple(result)

from time import time
N = 10 ** 7
notall1 = NotAll(1)
notall1000 = NotAll(1000)
start1 = time()
res1 = (min(range(N)), max(range(N)), sum(range(N)), NotAll(1)(range(N)),
        NotAll(1000)(range(N)))
stop1 = time()
print('5 iterators: %s %.2f' % (str(res1), stop1 - start1))

for p in range(5):
    start2 = time()
    res2 = multi_iter(range(N), min, max, sum, NotAll(1), NotAll(1000),
                      bufsize = 2**p)
    stop2 = time()
    print('multi_iter%d: %s %.2f' % (p, str(res2), stop2 - start2))

时间再次很糟糕,但您可以看到使用恒定大小的缓冲区如何显着改善事情:

5 iterators: (0, 9999999, 49999995000000, 0, 999) 0.71
multi_iter0: (0, 9999999, 49999995000000, 0, 999) 342.36
multi_iter1: (0, 9999999, 49999995000000, 0, 999) 264.71
multi_iter2: (0, 9999999, 49999995000000, 0, 999) 151.06
multi_iter3: (0, 9999999, 49999995000000, 0, 999) 95.79
multi_iter4: (0, 9999999, 49999995000000, 0, 999) 72.79

也许这可以作为良好实施的想法来源。

于 2020-04-09T10:56:13.723 回答
0

这是原始问题中概述的抢先线程解决方案的实现。

[编辑:这个实现有一个严重的问题。[编辑,现已修复,使用受 Daniel Junglas 启发的解决方案。]

不遍历整个可迭代对象的消费者将导致队列内部的空间泄漏Link。例如:


def exceeds_10(iterable):
    for item in iterable:
        if item > 10:
            return True
    return False

如果您将其用作消费者之一并使用源range(10**6),它将Link在前 11 个项目之后停止从队列中移除项目,留下大约10**6项目在队列中累积!

]


class Link:

    def __init__(self, queue):
        self.queue = queue

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is FINISHED:
            raise StopIteration
        return item

    def put(self, item):
        self.queue.put(item)

    def stop(self):
        self.queue.put(FINISHED)

    def consumer_not_listening_any_more(self):
        self.__class__ = ClosedLink


class ClosedLink:

    def put(self, _): pass
    def stop(self)  : pass


class FINISHED: pass


def make_thread(link, consumer, future):
    from threading import Thread
    return Thread(target = lambda: on_thread(link, consumer, future))

def on_thread(link, consumer, future):
    future.set_result(consumer(link))
    link.consumer_not_listening_any_more()

def source_summaries_PREEMPTIVE_THREAD(source, *consumers):
    from queue     import SimpleQueue as Queue
    from asyncio   import Future

    links   = tuple(Link(Queue()) for _ in consumers)
    futures = tuple(     Future() for _ in consumers)
    threads = tuple(map(make_thread, links, consumers, futures))

    for thread in threads:
        thread.start()

    for item in source:
        for link in links:
            link.put(item)

    for link in links:
        link.stop()

    for t in threads:
        t.join()

    return tuple(f.result() for f in futures)

它有效,但(不出所料)性能严重下降:

def time(thunk):
    from time import time
    start = time()
    thunk()
    stop  = time()
    return stop - start

N = 10 ** 7
t = time(lambda: testit(source_summaries, N))
print(f'old: {N} in {t:5.1f} s')

t = time(lambda: testit(source_summaries_PREEMPTIVE_THREAD, N))
print(f'new: {N} in {t:5.1f} s')

给予

old: 10000000 in   1.2 s
new: 10000000 in  30.1 s

因此,即使这是一个理论上的解决方案,它也不是一个实际的解决方案[*]。

因此,我认为这种方法是一条死胡同,除非有办法说服consumer合作让步(而不是强迫它先发制人地让步)

def on_thread(link, consumer, future):
    future.set_result(consumer(link))

...但这似乎根本不可能。很想被证明是错误的。

[*] 这实际上有点苛刻:测试对琐碎的数据完全没有任何作用;如果这是对元素执行大量计算的大型计算的一部分,那么这种方法可能真的很有用。

于 2020-04-08T14:22:10.753 回答