1

我尝试为 cps 版本的斐波那契函数实现蹦床。但我不能让它快速(添加缓存)并支持mutual_recursion。

实现代码:

import functools
from dataclasses import dataclass
from typing import Optional, Any, Callable

START = 0
CONTINUE = 1
CONTINUE_END = 2
RETURN = 3


@dataclass
class CTX:
    kind: int
    result: Any    # TODO ......
    f: Callable
    args: Optional[list]
    kwargs: Optional[dict]


def trampoline(f):
    ctx = CTX(START, None, None, None, None)

    @functools.wraps(f)
    def decorator(*args, **kwargs):
        nonlocal ctx
        if ctx.kind in (CONTINUE, CONTINUE_END):
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE
            return
        elif ctx.kind == START:
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE

        result = None
        while ctx.kind != RETURN:
            args = ctx.args
            kwargs = ctx.kwargs
            result = f(*args, **kwargs)
            if ctx.kind == CONTINUE_END:
                ctx.kind = RETURN
            else:
                ctx.kind = CONTINUE_END

        return result

    return decorator

这是可运行的示例。

@functools.lru_cache
def fib(n):
    if n == 0:
        return 1
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)

@trampoline
def fib_cps(n, k):
    if n == 0:
        return k(1)
    elif n == 1:
        return k(1)
    else:
        return fib_cps(n - 1, lambda v1: fib_cps(n - 2, lambda v2: k(v1 + v2)))

def fib_cps_wrapper(n):
    return fib_cps(n, lambda i:i)


@trampoline
def fib_tail(n, acc1=1, acc2=1):
    if n < 2:
        return acc1
    else:
        return fib_tail(n - 1, acc1 + acc2, acc1)


if __name__ == "__main__":
    print(fib(100))
    print(fib_tail(10000))
    print(fib_cps_wrapper(40))

跑号太慢了40。当更大时,超过fib了得到的最大递归深度。n但是添加后lru_cache会很快。iter trampolined 版本可以用于递归深度并且运行速度非常快。

这是其他一些人的工作:

  1. 支持cps版本缓存: https ://davywybiral.blogspot.com/2008/11/trampolining-for-recursion.html
  2. 支持mutual_recursion:https ://github.com/0x65/trampoline但它太难理解了。
4

1 回答 1

1

查看您共享的链接,有很多有趣的解决方案。我特别受此启发并改变了一些事情。回顾一下,您需要一个尾递归装饰器,它既可以缓存函数先前执行的结果,又支持相互递归(?)。还有另一个关于尾递归上下文中相互递归的有趣讨论,它可能会帮助您理解主要问题。


我已经编写了一个同时进行缓存和相互递归的装饰器:我认为它可以进一步简化/改进,但它适用于我选择的测试样本:

from collections import namedtuple
import functools

TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
    f._first_call = True
    f._cache = {}

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if f._first_call:
            f._new_args = args
            f._new_kwargs = kwargs
        
            try:
                f._first_call = False
                while True:
                    cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
                    if cache_key in f._cache:
                        return f._cache[cache_key]

                    result = f(*f._new_args, **f._new_kwargs)

                    if not isinstance(result, TailRecArguments):
                        f._cache[cache_key] = result

                    if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                        f._new_args = result.args
                        f._new_kwargs = result.kwargs
                    else:
                        break

                return result
            finally:
                f._first_call = True
        else:
            return TailRecArguments(f, args, kwargs)

    return wrapper

乍一看似乎很复杂,但它重用了链接中讨论的一些概念。


初始化

f._first_call = True
f._cache = {}

而不是像START,CONTINUE和那样的状态RETURN,在这种情况下,我只需要区分 the_first_call和以下状态。事实上,第一次调用函数后,下一次调用会返回一个TailRecArgument存储参数的函数。

f._cache是该特定功能的缓存。


尾递归

if f._first_call:
    f._new_args = args
    f._new_kwargs = kwargs

    try:
        f._first_call = False
        while True:
            result = f(*f._new_args, **f._new_kwargs)

            if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                f._new_args = result.args
                f._new_kwargs = result.kwargs
            else:
                break

        return result
    finally:
        f._first_call = True
else:
    return TailRecArguments(f, args, kwargs)

这个版本的尾递归如何工作?在while循环中,在第一次调用装饰函数后,使用返回的新参数连续调用该函数。

我什么时候可以退出循环?一旦返回的值不是 type TailRecArguments,这意味着最后一次函数调用没有递归调用自身,而是返回了一个实际值。在这种情况下,我只需要返回结果并设置f._first_call = True。不幸的是,它比这复杂一点,因为它不适用于相互递归。这里的解决方法是存储TailRecArguments甚至调用的函数。通过这种方式,我可以检查用于下一个循环的参数是用于同一个函数(result.wrapped_func == f)还是另一个尾递归函数。在后一种情况下,我不想处理这些参数,因为它们与另一个函数相关,而是我可以返回它们,因为它们肯定会在while遇到的第一个尾递归函数的循环中执行。唯一的缺点f._first_call每次参数属于另一个函数时都会重置。


缓存

while True:
    cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
    if cache_key in f._cache:
        return f._cache[cache_key]

    result = f(*f._new_args, **f._new_kwargs)

    if not isinstance(result, TailRecArguments):
        f._cache[cache_key] = result

在评论缓存机制(这是非常流行的记忆技术)之前,正确放置缓存代码很重要:注意我把它放在while循环中。不可能,因为只有在 while 循环内,该函数才会被连续调用,我可以检查缓存命中。

cache_key因为我使用了functools模块的内部函数,所以我在创作上作弊了一点。它是@cache同一个模块中的装饰器使用的,您可以使用

import inspect
import functools
print(inspect.getsource(functools._make_key))

还有其他方法可以从创建缓存键*args**kwargs就像这个一样,它再次指向_make_key. 为了让你的代码更稳定,当然要避免使用私有成员。

正如我所说,剩下的就是记忆,还有一个额外的检查:if not isinstance(result, TailRecArguments): .... 我想缓存值,而不是尾递归调用的参数。

(实际上,当递归调用返回实际值时,我认为您可以将所有内容临时存储TailRecArguments在一个列表中,并在缓存中添加与该列表大小一样多的条目。这会使解决方案复杂化,但如果您仍然可以接受有性能问题。这可能会在相互递归的情况下引发一些错误,如果需要,我将继续处理)。


测试

这些是我用来测试装饰器的一些基本功能:

@tail_recursive
def even(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> even(100)
    True
    >>> even(101)
    False
    """
    return True if n == 0 else odd(n - 1)

@tail_recursive
def odd(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> odd(100)
    False
    >>> odd(101)
    True
    """
    return False if n == 0 else even(n - 1)

@tail_recursive
def fact(n, acc=1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> fact(30)
    265252859812191058636308480000000
    """
    return acc if n <= 1 else fact(n - 1, acc * n)

@tail_recursive
def fib(n, a = 0, b = 1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(20)
    >>> fib(30)
    832040
    """
    return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)

if __name__ == '__main__':
    import doctest
    doctest.testmod()

请注意,缓存在这些示例中不是很有用,以阶乘为例:fact(10)is never going to use fact(8), in fact

fact(8) fact(10)
事实(10, 1)
事实(9, 10)
事实(8, 1) 事实(8, 90)
... ...

累加器是缓存键的一部分,因此您应该通过自定义要缓存的参数来更改缓存策略(同样,如果需要,我也可以为此提出解决方案)。


更新 - 缓存优化

这是对原始答案中使用的缓存策略的部分修复。主要问题是考虑到通用尾递归算法的工作原理(参见阶乘示例),在缓存键中包含所有参数效率低下。

第一个可能的优化是让用户选择哪些参数用于键,哪些参数用于值。由于类型提示,它的可读性要低得多,但是测试使一切变得更加清晰:

class Logger:
    def __init__(self, name):
        self._name = name
        self._entries = []
    
    def log(self, s):
        self._entries.append(s)

    def print(self):
        log_prefix = f"[{self._name}] - "
        print(log_prefix + f"\n{log_prefix}".join(self._entries))

TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
default_logger = Logger('default')
def tail_recursive(logger: Logger = default_logger, \
        get_cache_key: Callable[[Iterable, Dict], Hashable] = lambda args, kwargs: \
            functools._make_key(args, kwargs, False),\
        get_result_after_cache_hit: Callable[[Any, Iterable, Dict], Any] = lambda value, args, kwargs: \
            value):
    def decorator(f):
        f._first_call = True
        f._cache = {}

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if f._first_call:
                f._new_args = args
                f._new_kwargs = kwargs
            
                try:
                    f._first_call = False
                    f._initial_key = get_cache_key(f._new_args, f._new_kwargs)
                    while True:
                        cache_key = get_cache_key(f._new_args, f._new_kwargs)
                        if cache_key in f._cache:
                            logger.log('cache hit for ' + str(cache_key))
                            return get_result_after_cache_hit(f._cache[cache_key], f._new_args, f._new_kwargs)

                        result = f(*f._new_args, **f._new_kwargs)

                        if not isinstance(result, TailRecArguments):
                            f._cache[f._initial_key] = result

                        if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                            f._new_args = result.args
                            f._new_kwargs = result.kwargs
                        else:
                            break

                    return result
                finally:
                    f._first_call = True
            else:
                return TailRecArguments(f, args, kwargs)

        return wrapper
    return decorator

除了Logger仅用于确认缓存命中的类之外,主要区别在于每个函数现在都有一个名为 的新成员_initial_key,它存储第一次调用的键。这样,如果我调用fact(5),5就变成了_initial_key并且结果被放入f._cache[5]

这可以优化相互递归和尾递归函数,但在某些情况下无效。让我们从最好的情况开始:

fact_logger = Logger('fact')
@tail_recursive(logger=fact_logger, get_cache_key=lambda args, kwargs: args[0],\
    get_result_after_cache_hit=lambda value, args, kwargs: value * args[1])
def fact(n, acc=1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> fact(5)
    120
    >>> fact(30)
    265252859812191058636308480000000
    >>> fact_logger.print()
    [fact] - cache hit for 5
    """
    return acc if n <= 1 else fact(n - 1, acc * n)

@tail_recursive装饰器初始化包括(记录器),它get_cache_key指定只有第一个参数n应该是缓存键的一部分,并get_result_after_cache_hit指定在缓存命中后如何产生最终结果。在上述情况下,当fact(30)达到时fact(5, <partial_factorial>),结果立即计算为<partial_factorial> * f._cache[5]

也是如此even-odd,除了在这种情况下,默认参数tail_recursive绰绰有余:

even_logger = Logger('even')
@tail_recursive(logger=even_logger)
def even(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> even(100)
    True
    >>> even(101)
    False
    >>> even(104)
    True
    >>> even_logger.print()
    [even] - cache hit for 100
    """
    return True if n == 0 else odd(n - 1)

不幸的是,这不适用于例如斐波那契函数。您应该通过在每次调用期间打印参数来轻松地说服自己,结果如下:

30 0 1
29 1 1
28 1 2
27 2 3
26 3 5
25 5 8
...

建立缓存键规则需要一个更复杂的逻辑,这可能会使tail_recursive装饰器变得非常不可读且可移植性较差。

于 2022-02-16T13:54:06.863 回答