4

我正在 Python 中实现 Trampoline,以便编写具有堆栈安全性的递归函数(因为 CPython 不具有 TCO)。它看起来像这样:

from typing import Callable, Generic, TypeVar, cast
from abc import ABC, abstractmethod

# Type variables used by the trampoline classes below. ``Callable``, ``cast``,
# ``B`` and ``C`` are referenced in the class and method signatures, so they
# must exist before those definitions are evaluated.
A = TypeVar('A', covariant=True)
B = TypeVar('B')
C = TypeVar('C')


class Trampoline(Generic[A], ABC):
    """
    Abstract base class for trampolines.

    A trampoline describes a computation as a chain of small steps that
    ``run`` executes in a loop, which makes it useful for writing
    stack-safe recursive functions.
    """
    @abstractmethod
    def _resume(self) -> 'Trampoline[A]':
        """
        Perform one step of this trampoline.

        Return:
            the trampoline the loop in ``run`` should continue with
        """

    @abstractmethod
    def _handle_cont(
        self, cont: Callable[[A], 'Trampoline[B]']
    ) -> 'Trampoline[B]':
        """
        Combine this trampoline with a continuation function that was
        passed to ``and_then``. ``AndThen._resume`` calls this on its
        sub-trampoline.

        Args:
            cont: continuation to apply to the value of this trampoline
        Return:
            the trampoline to continue interpreting with
        """

    @property
    def _is_done(self) -> bool:
        # ``run`` keeps looping until it reaches a ``Done`` instance.
        return isinstance(self, Done)

    def and_then(self, f: Callable[[A], 'Trampoline[B]']) -> 'Trampoline[B]':
        """
        Chain ``f`` after this trampoline (monadic bind).

        The call is not performed here; it is recorded as an ``AndThen``
        node and carried out later by ``run``.

        Args:
            f: function from the value wrapped by this trampoline to a
               new trampoline
        Return:
            trampoline representing ``f`` applied to the value wrapped
            by this trampoline
        """
        return AndThen(self, f)

    def map(self, f: Callable[[A], B]) -> 'Trampoline[B]':
        """
        Transform the value wrapped by this trampoline with ``f``.

        Args:
            f: plain function to apply to the wrapped value
        Return:
            new trampoline wrapping the result of ``f``
        """
        def wrap_result(value):
            # Lift the plain result back into a finished trampoline.
            return Done(f(value))

        return self.and_then(wrap_result)

    def run(self) -> A:
        """
        Interpret this structure of trampolines until a result is produced.

        Return:
            the value held by the final ``Done`` trampoline
        """
        current = self
        while True:
            if current._is_done:
                return cast(Done[A], current).a
            current = current._resume()


class Done(Trampoline[A]):
    """
    Represents the result of a recursive computation.

    Args:
        a: the final value of the computation
    """
    a: A

    def __init__(self, a: A) -> None:
        # The class-level annotation alone does not create an initializer;
        # without this method ``Done(value)`` raises ``TypeError``.
        self.a = a

    def _resume(self) -> Trampoline[A]:
        """
        A finished computation has nothing left to do, so it resumes
        as itself.
        """
        return self

    def _handle_cont(self,
                     cont: Callable[[A], Trampoline[B]]) -> Trampoline[B]:
        """
        Apply the continuation directly to the finished value.

        Args:
            cont: continuation to apply to ``self.a``
        Return:
            the trampoline returned by ``cont``
        """
        return cont(self.a)


class Call(Trampoline[A]):
    """
    Represents a recursive call.

    Args:
        thunk: zero-argument function that produces the next trampoline
    """
    thunk: Callable[[], Trampoline[A]]

    def __init__(self, thunk: Callable[[], Trampoline[A]]) -> None:
        # The class-level annotation alone does not create an initializer;
        # without this method ``Call(thunk)`` raises ``TypeError``.
        # Storing the function on the instance also means ``self.thunk()``
        # calls it without binding ``self``.
        self.thunk = thunk

    def _handle_cont(self,
                     cont: Callable[[A], Trampoline[B]]) -> Trampoline[B]:
        """
        Evaluate the thunk and chain ``cont`` after the trampoline it returns.

        Args:
            cont: continuation to chain after the result of the thunk
        Return:
            the trampoline produced by ``and_then`` on the thunk's result
        """
        return self.thunk().and_then(cont)  # type: ignore

    def _resume(self) -> Trampoline[A]:
        """
        Evaluate the thunk and return the trampoline it produces.
        """
        return self.thunk()  # type: ignore


class AndThen(Generic[A, B], Trampoline[B]):
    """
    Represents monadic bind for trampolines as a class to avoid
    deep recursive calls to ``Trampoline.run`` during interpretation.

    Args:
        sub: trampoline whose value is fed to ``cont``
        cont: continuation applied to the value of ``sub``
    """
    sub: Trampoline[A]
    cont: Callable[[A], Trampoline[B]]

    def __init__(self,
                 sub: Trampoline[A],
                 cont: Callable[[A], Trampoline[B]]) -> None:
        # The class-level annotations alone do not create an initializer;
        # without this method ``AndThen(sub, cont)`` raises ``TypeError``.
        self.sub = sub
        self.cont = cont

    def _handle_cont(self,
                     cont: Callable[[B], Trampoline[C]]) -> Trampoline[C]:
        """
        Chain ``cont`` after this bind by re-binding ``sub`` with
        ``self.cont`` and then with ``cont``.

        Args:
            cont: continuation to chain after this trampoline
        Return:
            the trampoline to continue interpreting with
        """
        return self.sub.and_then(self.cont).and_then(cont)  # type: ignore

    def _resume(self) -> Trampoline[B]:
        """
        Let the sub-trampoline decide how to proceed with ``self.cont``.
        """
        return self.sub._handle_cont(self.cont)  # type: ignore

    def and_then(  # type: ignore
        self, f: Callable[[A], Trampoline[B]]
    ) -> Trampoline[B]:
        """
        Chain ``f`` after this bind.

        Instead of nesting this ``AndThen`` inside a new one, the result
        keeps ``self.sub`` as its sub-trampoline and wraps the composition
        of ``self.cont`` and ``f`` in a ``Call``, so that composition is
        only evaluated when the ``Call`` is resumed.

        Args:
            f: function to chain after ``self.cont``
        Return:
            new ``AndThen`` over ``self.sub``
        """
        return AndThen(
            self.sub,
            lambda x: Call(lambda: self.cont(x).and_then(f))  # type: ignore
        )

现在,我需要一个单子(monad)的 sequence 运算符。我最初的实现是这样的:

from typing import Iterable

from functools import reduce


def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
    """
    Combine the trampolines in ``iterable`` into one trampoline whose
    value is a tuple of their values, in iteration order.

    Args:
        iterable: trampolines to combine
    Return:
        trampoline wrapping the tuple of collected values
    """
    def combine(result: Trampoline[Iterable[A]], ta: Trampoline[A]) -> Trampoline[Iterable[A]]:
        def extend_with_ta(as_):
            def append(a):
                # Build a new tuple with ``a`` added at the end.
                return as_ + (a,)

            return ta.map(append)

        return result.and_then(extend_with_ta)

    # Left fold over the trampolines, starting from an empty tuple.
    accumulated = Done(())
    for trampoline in iterable:
        accumulated = combine(accumulated, trampoline)
    return accumulated

这行得通,但是以这种方式对一长串蹦床做归约(reduce)会产生大量函数调用,其开销会严重拖累性能。

所以我尝试了这个:

def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
    """
    Combine the trampolines in ``iterable`` into one trampoline whose
    value is a tuple of their values, in iteration order.

    Each trampoline is run to completion with its own ``run`` call
    inside the thunk, so nothing is evaluated until the returned
    ``Call`` is resumed.

    Args:
        iterable: trampolines to combine
    Return:
        ``Call`` that yields a ``Done`` wrapping the tuple of values
    """
    def thunk() -> Trampoline[Iterable[A]]:
        values = []
        for trampoline in iterable:
            # Nested interpreter loop: each element is fully evaluated here.
            values.append(trampoline.run())
        return Done(tuple(values))

    return Call(thunk)

现在,我的直觉是,sequence 的第二个实现不是堆栈安全的,因为它调用了 run,这意味着 run 会在解释期间再次调用 run(虽然是通过 Call.thunk 间接调用,但终究还是调用了)。但是,无论我如何组合,我似乎都无法造成堆栈溢出。

例如,我认为应该这样做:

# Build 10000 trampolines, each produced by ``sequence`` over a generator of
# two ``Done`` values; ``t`` is the first one and ``ts`` holds the rest.
t, *ts = [sequence(Done(v) for v in range(2)) for _ in range(10000)]

def combine(t1, t2):
    """
    Chain ``t2`` after ``t1``, discarding the value produced by ``t1``.
    """
    def ignore_value_and_continue(_):
        return t2

    return t1.and_then(ignore_value_and_continue)

# Chain every trampoline in ``ts`` after ``t`` using ``combine``, then
# interpret the resulting structure.
final = reduce(combine, ts, t)
final.run()  # My gut feeling says this should overflow the stack, but it doesn't

我尝试了无数其他示例,但没有堆栈溢出。我的直觉仍然认为这是行不通的。

我需要有人说服我:以这种方式对解释器循环进行蹦床化实际上是堆栈安全的;或者给我一个会导致堆栈溢出的示例。

4

1 回答 1

0

要造成堆栈溢出,需要在解释期间发生递归,例如层层嵌套地调用 sequence:

sequence([sequence([sequence([sequence([...
于 2020-08-12T03:40:52.503 回答