8

I am using the unittest framework to automate integration tests of multi-threaded python code, external hardware and embedded C. Despite my blatant abuse of a unittesting framework for integration testing, it works really well. Except for one problem: I need the test to fail if an exception is raised from any of the spawned threads. Is this possible with the unittest framework?

A simple but non-workable solution would be to either a) refactor the code to avoid multi-threading or b) test each thread separately. I cannot do that because the code interacts asynchronously with the external hardware. I have also considered implementing some kind of message passing to forward the exceptions to the main unittest thread. This would require significant testing-related changes to the code being tested, and I want to avoid that.

Time for an example. Can I modify the test script below to fail on the exception raised in my_thread without modifying the x.ExceptionRaiser class?

import unittest
import x

class Test(unittest.TestCase):
    def test_x(self):
        my_thread = x.ExceptionRaiser()
        # Test case should fail when thread is started and raises
        # an exception.
        my_thread.start()
        my_thread.join()

if __name__ == '__main__':
    unittest.main()
4

3 回答 3

6

起初,sys.excepthook看起来像一个解决方案。它是一个全局钩子,每次抛出未捕获的异常时都会调用它。

不幸的是,这不起作用。为什么?很好地threading将您的函数包装run在代码中,该代码会打印您在屏幕上看到的可爱回溯(注意它总是如何告诉您Exception in thread {Name of your thread here}的?这就是它的完成方式)。

从 Python 3.8 开始,您可以重写一个函数来完成这项工作:threading.excepthook

... threading.excepthook() 可以被覆盖以控制如何处理 Thread.run() 引发的未捕获异常

那么我们该怎么办?用我们的逻辑替换这个函数,

对于 python >= 3.8

import traceback
import threading 
import os


class GlobalExceptionWatcher(object):
    def _store_excepthook(self, args):
        '''
        Uses as an exception handlers which stores any uncaught exceptions.
        '''
        self.__org_hook(args)
        formated_exc = traceback.format_exception(args.exc_type, args.exc_value, args.exc_traceback)
        self._exceptions.append('\n'.join(formated_exc))
        return formated_exc

    def __enter__(self):
        '''
        Register us to the hook.
        '''
        self._exceptions = []
        self.__org_hook = threading.excepthook
        threading.excepthook = self._store_excepthook

    def __exit__(self, type, value, traceback):
        '''
        Remove us from the hook, assure no exception were thrown.
        '''
        threading.excepthook = self.__org_hook
        if len(self._exceptions) != 0:
            tracebacks = os.linesep.join(self._exceptions)
            raise Exception(f'Exceptions in other threads: {tracebacks}')

对于旧版本的 Python,这有点复杂。长话短说,threading结节似乎有一个未记录的进口,其作用类似于:

threading._format_exc = traceback.format_exc

毫不奇怪,这个函数只有在线程函数抛出异常时才会被调用run

所以对于python <= 3.7

import threading 
import os

class GlobalExceptionWatcher(object):
    def _store_excepthook(self):
        '''
        Uses as an exception handlers which stores any uncaught exceptions.
        '''
        formated_exc = self.__org_hook()
        self._exceptions.append(formated_exc)
        return formated_exc
        
    def __enter__(self):
        '''
        Register us to the hook.
        '''
        self._exceptions = []
        self.__org_hook = threading._format_exc
        threading._format_exc = self._store_excepthook
        
    def __exit__(self, type, value, traceback):
        '''
        Remove us from the hook, assure no exception were thrown.
        '''
        threading._format_exc = self.__org_hook
        if len(self._exceptions) != 0:
            tracebacks = os.linesep.join(self._exceptions)
            raise Exception('Exceptions in other threads: %s' % tracebacks)

用法:

my_thread = x.ExceptionRaiser()
# will fail when thread is started and raises an exception.
with GlobalExceptionWatcher():
    my_thread.start()
    my_thread.join()
            

您仍然需要join自己,但在退出时,with 语句的上下文管理器将检查其他线程中抛出的任何异常,并适当地引发异常。


代码按“原样”提供,不提供任何形式的明示或暗示保证

这是一个无证的、可怕的黑客攻击。我在linux和windows上测试过,似乎可以。需要您自担风险使用它。

于 2012-09-19T13:37:43.237 回答
1

我自己也遇到过这个问题,我能想出的唯一解决方案是将 Thread 子类化以包含一个属性,以确定它是否在没有未捕获异常的情况下终止:

from threading import Thread

class ErrThread(Thread):
    """                                                                                                                                                                                               
    A subclass of Thread that will log store exceptions if the thread does                                                                                                                            
    not exit normally                                                                                                                                                                                 
    """
    def run(self):
        try:
            Thread.run(self)
        except Exception as self.err:
            pass
        else:
            self.err = None


class TaskQueue(object):
    """                                                                                                                                                                                               
    A utility class to run ErrThread objects in parallel and raises and exception                                                                                                                     
    in the event that *any* of them fail.                                                                                                                                                             
    """

    def __init__(self, *tasks):

        self.threads = []

        for t in tasks:
            try:
                self.threads.append(ErrThread(**t)) ## passing in a dict of target and args
            except TypeError:
                self.threads.append(ErrThread(target=t))

    def run(self):

        for t in self.threads:
            t.start()
        for t in self.threads:
            t.join()
            if t.err:
                raise Exception('Thread %s failed with error: %s' % (t.name, t.err))
于 2012-09-29T09:11:58.373 回答
1

我已经使用上面接受的答案一段时间了,但是从 Python 3.8 开始,该解决方案不再起作用,因为threading模块不再具有此_format_exc导入。

另一方面,该threading模块现在有一个很好的方法来注册自定义除了 Python 3.8 中的钩子,所以这是一个运行单元测试的简单解决方案,它断言在线程内引发了一些异常:

def test_in_thread():
    import threading

    exceptions_caught_in_threads = {}

    def custom_excepthook(args):
        thread_name = args.thread.name
        exceptions_caught_in_threads[thread_name] = {
            'thread': args.thread,
            'exception': {
                'type': args.exc_type,
                'value': args.exc_value,
                'traceback': args.exc_traceback
            }
        }

    # Registering our custom excepthook to catch the exception in the threads
    threading.excepthook = custom_excepthook

    # dummy function that raises an exception
    def my_function():
        raise Exception('My Exception')

    # running the funciton in a thread
    thread_1 = threading.Thread(name='thread_1', target=my_function, args=())

    thread_1.start()
    thread_1.join()

    assert 'thread_1' in exceptions_caught_in_threads  # there was an exception in thread 1
    assert exceptions_caught_in_threads['thread_1']['exception']['type'] == Exception
    assert str(exceptions_caught_in_threads['thread_1']['exception']['value']) == 'My Exception'
于 2020-10-20T16:18:44.070 回答