0

我正在使用 Selenium 和 Flask-Testing 的 `LiveServerTestCase` 为我的 Flask Web 应用程序编写单元测试(unittest)。

运行测试时,我希望所有测试共用同一个浏览器,而不是为每个测试都打开一个新的浏览器实例,因此我使用了 unittest 的 `setUpClass`:

class TestApp(LiveServerTestCase, unittest.TestCase):
    """Selenium tests that share a single Chrome browser across all tests."""

    def create_app(self):
        """Build the app under test, listening on a fixed live-server port."""
        app = create_app()
        app.config['TESTING'] = True
        app.config.update(LIVESERVER_PORT=9898)
        return app

    @classmethod
    def setUpClass(cls):
        """Start one browser for the whole class.

        The browser must not navigate here: ``get_server_url()`` is an
        instance method, so calling ``cls.get_server_url()`` fails with
        ``TypeError: get_server_url() missing 1 required positional
        argument: 'self'``. Navigation happens in ``setUp`` instead.
        """
        super().setUpClass()
        cls.chrome_browser = webdriver.Chrome()

    @classmethod
    def tearDownClass(cls):
        """Close the shared browser once all tests in the class have run."""
        try:
            cls.chrome_browser.quit()
        finally:
            super().tearDownClass()

    def setUp(self):
        """Open the live server's root URL before each test."""
        # setUp runs on the test instance, inside LiveServerTestCase.__call__
        # (where the live server is started), so the instance method
        # get_server_url() is usable here.
        self.chrome_browser.get(self.get_server_url())

    def test_main_page(self):
        self.assertEqual(1, 1)

运行测试时,出现以下错误:

TypeError: get_server_url() missing 1 required positional argument: 'self'

应该如何在 `setUpClass` 中设置浏览器?

4

1 回答 1

0

您必须按照 Flask-Testing 开发者设计的方式来使用它:在 `__call__` 方法运行时(即每个测试运行时)启动 Flask 服务器和 Selenium 驱动程序。也正因为如此,`get_server_url()` 是一个实例方法,无法在 `setUpClass` 中通过类来调用,这就是上面 `TypeError` 的原因……

或者,您也可以重写(override)这部分逻辑,如下例所示:在 `setUpClass` 中只创建一次 Selenium 驱动程序,而在每个测试运行时启动一个全新的 Flask 服务器。

import multiprocessing
import socket
import socketserver
import time
from urllib.parse import urlparse, urljoin

from flask import Flask
from flask_testing import LiveServerTestCase
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.wait import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC


class MyTest(LiveServerTestCase):
    """Selenium tests sharing one browser, with a fresh live server per test.

    ``setUpClass`` builds the app, pushes a request context and starts the
    Firefox driver once for the whole class. The overridden ``__call__``
    starts a new Flask server process before each test and stops it after.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Create the app, the shared port holder, the driver and the context."""
        super().setUpClass()
        # Get the app
        cls.app = cls.create_app()

        cls._configured_port = cls.app.config.get('LIVESERVER_PORT', 5000)
        # Shared with the server child process, which writes back the port it
        # actually bound (needed when LIVESERVER_PORT is 0).
        cls._port_value = multiprocessing.Value('i', cls._configured_port)

        # Start the browser before pushing the context: if setUpClass raises,
        # unittest skips tearDownClass, so a pushed context would leak.
        cls.driver = webdriver.Firefox()
        try:
            # We need to create a context in order for extensions to catch up
            cls._ctx = cls.app.test_request_context()
            cls._ctx.push()
        except Exception:
            cls.driver.quit()
            raise

    @classmethod
    def tearDownClass(cls) -> None:
        """Release the class-level resources created in setUpClass."""
        try:
            cls._post_teardown()
        finally:
            super().tearDownClass()

    @classmethod
    def get_server_url(cls):
        """
        Return the url of the test server
        """
        return 'http://localhost:%s' % cls._port_value.value

    @classmethod
    def create_app(cls):
        """Build a minimal Flask app with a single ``/`` route for the tests."""
        app = Flask(__name__)

        @app.route('/')
        def hello_world():
            return 'Hello, World!'

        app.config['TESTING'] = True
        app.config['DEBUG'] = True
        app.config['ENV'] = "development"
        # Default port is 5000
        app.config['LIVESERVER_PORT'] = 8943
        # Default timeout is 5 seconds
        app.config['LIVESERVER_TIMEOUT'] = 10
        return app

    def __call__(self, *args, **kwargs):
        """Run a single test with its own live server.

        ``super(LiveServerTestCase, self)`` deliberately bypasses
        ``LiveServerTestCase.__call__`` so this class's own per-test server
        handling is used. The server is always terminated, even if the test
        or the server start-up fails.
        """
        try:
            self._spawn_live_server()
            return super(LiveServerTestCase, self).__call__(*args, **kwargs)
        finally:
            self._terminate_live_server()

    @classmethod
    def _post_teardown(cls):
        """Pop the request context pushed in setUpClass and quit the driver."""
        if getattr(cls, '_ctx', None) is not None:
            cls._ctx.pop()
            del cls._ctx
        driver = getattr(cls, 'driver', None)
        if driver is not None:
            driver.quit()

    @classmethod
    def _terminate_live_server(cls):
        """Stop the current test's server process and wait for it to exit."""
        process = getattr(cls, '_process', None)
        if process is None:
            return
        process.terminate()
        # Wait for the child to exit so its port is released before the next
        # test spawns a server on the same port; otherwise the readiness ping
        # in _spawn_live_server could reach the dying previous server.
        process.join(cls.app.config.get('LIVESERVER_TIMEOUT', 5))
        cls._process = None

    @classmethod
    def _spawn_live_server(cls):
        """Start the app in a child process and block until it accepts connections.

        Raises:
            RuntimeError: if the server cannot be reached within
                ``LIVESERVER_TIMEOUT`` seconds (default 5).
        """
        cls._process = None
        port_value = cls._port_value
        # The Value is created once per class, so it may still hold the port
        # bound by the previous test's server; reset it before spawning.
        port_value.value = cls._configured_port

        def worker(app, port):
            # Based on solution: http://stackoverflow.com/a/27598916
            # Monkey-patch the server_bind so we can determine the port bound by Flask.
            # This handles the case where the port specified is `0`, which means that
            # the OS chooses the port. This is the only known way (currently) of getting
            # the port out of Flask once we call `run`.
            original_socket_bind = socketserver.TCPServer.server_bind

            def socket_bind_wrapper(self):
                ret = original_socket_bind(self)

                # Get the port and save it into the port_value, so the parent process
                # can read it.
                (_, bound_port) = self.socket.getsockname()
                port_value.value = bound_port
                socketserver.TCPServer.server_bind = original_socket_bind
                return ret

            socketserver.TCPServer.server_bind = socket_bind_wrapper
            app.run(port=port, use_reloader=False)

        cls._process = multiprocessing.Process(
            target=worker, args=(cls.app, cls._configured_port)
        )
        cls._process.start()

        # We must wait for the server to start listening, but give up
        # after a specified maximum timeout. monotonic() is immune to
        # wall-clock adjustments.
        timeout = cls.app.config.get('LIVESERVER_TIMEOUT', 5)
        deadline = time.monotonic() + timeout

        while not cls._can_ping_server():
            if time.monotonic() > deadline:
                raise RuntimeError(
                    "Failed to start the server after %d seconds. " % timeout
                )
            # Don't spin the CPU while the child process is starting up.
            time.sleep(0.05)

    @classmethod
    def _get_server_address(cls):
        """
        Gets the server address used to test the connection with a socket.
        Respects both the LIVESERVER_PORT config value and overriding
        get_server_url()
        """
        parts = urlparse(cls.get_server_url())

        host = parts.hostname
        port = parts.port

        if port is None:
            if parts.scheme == 'http':
                port = 80
            elif parts.scheme == 'https':
                port = 443
            else:
                raise RuntimeError(
                    "Unsupported server url scheme: %s" % parts.scheme
                )

        return host, port

    @classmethod
    def _can_ping_server(cls):
        """Return True if a TCP connection to the live server succeeds."""
        host, port = cls._get_server_address()
        if port == 0:
            # Port specified by the user was 0, and the OS has not yet assigned
            # the proper port.
            return False

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.connect((host, port))
            except OSError:
                return False
        return True

    def test_main_page(self):
        """The root page renders the greeting from create_app."""
        self.driver.get(self.get_server_url())
        body = WebDriverWait(self.driver, 10).until(
            EC.visibility_of_element_located((By.TAG_NAME, "body"))
        )
        self.assertEqual(body.text, "Hello, World!")

    def test_main_page_once_more_time(self):
        """An unknown path yields Flask's 'Not Found' page."""
        self.driver.get(urljoin(self.get_server_url(), "some/wrong/path"))
        body = WebDriverWait(self.driver, 10).until(
            EC.visibility_of_element_located((By.TAG_NAME, "body"))
        )
        self.assertTrue(body.text.startswith("Not Found"))

于 2020-11-14T19:58:30.023 回答