3

我有一组相关的类,它们都继承自一个基类。我想使用工厂方法来实例化这些类的对象。我想这样做是因为我可以在将对象返回给调用者之前将对象存储在由类名键入的字典中。然后,如果有对特定类的对象的请求,我可以检查一个对象是否已经存在于我的字典中。如果没有,我将实例化它并将其添加到字典中。如果是这样,那么我将从字典中返回现有对象。这实际上会将我模块中的所有类变成单例。

我想这样做是因为它们都继承自的基类对子类中的函数进行了一些自动包装,而且我不希望函数被多次包装,如果两个对象当前会发生这种情况创建相同的类。

我能想到的唯一方法是检查__init__()基类方法中的堆栈跟踪,该方法将始终被调用,如果堆栈跟踪未显示创建对象的请求来自工厂功能。

这是一个好主意吗?

编辑:这是我的基类的源代码。有人告诉我,我需要找出元类来更优雅地完成这项工作,但这就是我现在所拥有的。所有 Page 对象都使用相同的 Selenium Webdriver 实例,该实例位于顶部导入的驱动程序模块中。这个驱动程序的初始化成本很高——它在第一次创建 LoginPage 时被初始化。初始化后,该initialize()方法将返回现有驱动程序而不是创建新驱动程序。这个想法是用户必须从创建一个 LoginPage 开始。最终将定义数十个 Page 类,它们将被单元测试代码用于验证网站的行为是否正确。

from driver import get_driver, urlpath, initialize
from settings import urlpaths

class DriverPageMismatchException(Exception):
    pass

class URLVerifyingPage(object):
    # we add logic in __init__() to check the expected urlpath for the page
    # against the urlpath that the driver is showing - we only want the page's
    # methods to be invokable if the driver is actualy at the appropriate page.
    # If the driver shows a different urlpath than the page is supposed to
    # have, the method should throw a DriverPageMismatchException

    def __init__(self):
        self.driver = get_driver()
        self._adjust_methods(self.__class__)

    def _adjust_methods(self, cls):
        for attr, val in cls.__dict__.iteritems():
            if callable(val) and not attr.startswith("_"):
                print "adjusting:"+str(attr)+" - "+str(val)
                setattr(
                    cls,
                    attr,
                    self._add_wrapper_to_confirm_page_matches_driver(val)
                )
        for base in cls.__bases__:
            if base.__name__ == 'URLVerifyingPage': break
            self._adjust_methods(base)

    def _add_wrapper_to_confirm_page_matches_driver(self, page_method):
        def _wrapper(self, *args, **kwargs):
            if urlpath() != urlpaths[self.__class__.__name__]:
                raise DriverPageMismatchException(
                    "path is '"+urlpath()+
                    "' but '"+urlpaths[self.__class.__name__]+"' expected "+
                    "for "+self.__class.__name__
                )
            return page_method(self, *args, **kwargs)
        return _wrapper


class LoginPage(URLVerifyingPage):
    def __init__(self, username=username, password=password, baseurl="http://example.com/"):
        self.username = username
        self.password = password
        self.driver = initialize(baseurl)
        super(LoginPage, self).__init__()

    def login(self):
        driver.find_element_by_id("username").clear()
        driver.find_element_by_id("username").send_keys(self.username)
        driver.find_element_by_id("password").clear()
        driver.find_element_by_id("password").send_keys(self.password)
        driver.find_element_by_id("login_button").click()
        return HomePage()

class HomePage(URLVerifyingPage):
    def some_method(self):
        ...
        return SomePage()

    def many_more_methods(self):
        ...
        return ManyMorePages()

如果一个页面被实例化了几次,这没什么大不了的——这些方法只会被包装几次,并且会发生一些不必要的检查,但一切仍然有效。但是如果一个页面被实例化几十、几百或几万次,那就不好了。我可以在每个页面的类定义中放置一个标志并检查方法是否已经被包装,但我喜欢保持类定义纯净和干净并将所有 hocus-pocus 推入深角落的想法我的系统没有人可以看到它,它只是工作。

4

4 回答 4

3

在 Python 中,几乎不值得尝试“强制”任何东西。无论你想出什么,有人可以通过猴子修补你的课程、复制和编辑源代码、玩弄字节码等来绕过它。

因此,只需编写您的工厂,并将其记录为获取类实例的正确方法,并期望任何使用您的类编写代码的人都能理解 TOOWTDI,除非她真的知道自己在做什么并且愿意,否则不要违反它找出并处理后果。

如果您只是想防止事故发生,而不是故意“误用”,那就另当别论了。事实上,这只是标准的合同设计:检查不变量。当然此时,SillyBaseClass已经搞砸了,再修也来不及了,你能做的只有assert, raise, log, 或者其他合适的。但这就是您想要的:这是应用程序中的逻辑错误,唯一要做的就是让程序员修复它,这assert可能正是您想要的。

所以:

class SillyBaseClass:
    singletons = {}

class Foo(SillyBaseClass):
    def __init__(self):
        assert self.__class__ not in SillyBaseClass.singletons

def get_foo():
    if Foo not in SillyBaseClass.singletons:
        SillyBaseClass.singletons[Foo] = Foo()
    return SillyBaseClass.singletons[Foo]

如果你真的想阻止事情走到这一步,你可以在方法中更早地检查不变量__new__,但除非“SillyBaseClass搞砸了”等同于“发射核武器”,否则何必呢?

于 2013-01-01T06:42:14.990 回答
2

听起来您想提供一个__new__实现:类似于:

class MySingledtonBase(object):
    instance_cache = {}
    def __new__(cls, arg1, arg2):
        if cls in MySingletonBase.instance_cache:
            return MySingletonBase.instance_cache[cls]
        self = super(MySingletonBase, cls).__new__(arg1, arg2)
        MySingletonBase.instance_cache[cls] = self
        return self
于 2013-01-01T06:32:46.320 回答
2

与其添加复杂的代码来在运行时捕获错误,我会首先尝试使用约定来指导模块的用户自己做正确的事情。

给你的类“私有”名称(以下划线为前缀),给它们命名暗示它们不应该被实例化(例如_Internal ...)并使你的工厂函数“公共”。

也就是说,像这样:

class _InternalSubClassOne(_BaseClass):
    ...

class _InternalSubClassTwo(_BaseClass):
    ...

# An example factory function.
def new_object(arg):
    return _InternalSubClassOne() if arg == 'one' else _InternalSubClassTwo()

我还会为每个类添加文档字符串或注释,例如“不要手动实例化这个类,使用工厂方法new_object。”

于 2013-01-01T09:00:28.563 回答
0

您也可以在工厂方法中嵌套类,如下所述: https ://python-3-patterns-idioms-test.readthedocs.io/en/latest/Factory.html#preventing-direct-creation

来自上述来源的工作示例:

# Factory/shapefact1/NestedShapeFactory.py
import random

class Shape(object):
    types = []

def factory(type):
    class Circle(Shape):
        def draw(self): print("Circle.draw")
        def erase(self): print("Circle.erase")

    class Square(Shape):
        def draw(self): print("Square.draw")
        def erase(self): print("Square.erase")

    if type == "Circle": return Circle()
    if type == "Square": return Square()
    assert 0, "Bad shape creation: " + type

def shapeNameGen(n):
    for i in range(n):
        yield factory(random.choice(["Circle", "Square"]))

# Circle() # Not defined

for shape in shapeNameGen(7):
    shape.draw()
    shape.erase()

我不喜欢这个解决方案,只是想将其添加为一个选项。

于 2021-03-29T09:05:17.967 回答