查看您共享的链接,有很多有趣的解决方案。我特别受此启发并改变了一些事情。回顾一下,您需要一个尾递归装饰器,它既可以缓存函数先前执行的结果,又支持相互递归(?)。还有另一个关于尾递归上下文中相互递归的有趣讨论,它可能会帮助您理解主要问题。
我已经编写了一个同时进行缓存和相互递归的装饰器:我认为它可以进一步简化/改进,但它适用于我选择的测试样本:
from collections import namedtuple
import functools
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper
乍一看似乎很复杂,但它重用了链接中讨论的一些概念。
初始化
f._first_call = True
f._cache = {}
而不是像START
,CONTINUE
和那样的状态RETURN
,在这种情况下,我只需要区分 the_first_call
和以下状态。事实上,第一次调用函数后,下一次调用会返回一个TailRecArgument
存储参数的函数。
f._cache
是该特定功能的缓存。
尾递归
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
result = f(*f._new_args, **f._new_kwargs)
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
这个版本的尾递归如何工作?在while
循环中,在第一次调用装饰函数后,使用返回的新参数连续调用该函数。
我什么时候可以退出循环?一旦返回的值不是 type TailRecArguments
,这意味着最后一次函数调用没有递归调用自身,而是返回了一个实际值。在这种情况下,我只需要返回结果并设置f._first_call = True
。不幸的是,它比这复杂一点,因为它不适用于相互递归。这里的解决方法是存储TailRecArguments
甚至调用的函数。通过这种方式,我可以检查用于下一个循环的参数是用于同一个函数(result.wrapped_func == f
)还是另一个尾递归函数。在后一种情况下,我不想处理这些参数,因为它们与另一个函数相关,而是我可以返回它们,因为它们肯定会在while
遇到的第一个尾递归函数的循环中执行。唯一的缺点是f._first_call
每次参数属于另一个函数时都会重置。
缓存
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
在评论缓存机制(这是非常流行的记忆技术)之前,正确放置缓存代码很重要:注意我把它放在while
循环中。不可能,因为只有在 while 循环内,该函数才会被连续调用,我可以检查缓存命中。
cache_key
因为我使用了functools
模块的内部函数,所以我在创作上作弊了一点。它是@cache
同一个模块中的装饰器使用的,您可以使用
import inspect
import functools
print(inspect.getsource(functools._make_key))
还有其他方法可以从创建缓存键*args
,**kwargs
就像这个一样,它再次指向_make_key
. 为了让你的代码更稳定,当然要避免使用私有成员。
正如我所说,剩下的就是记忆,还有一个额外的检查:if not isinstance(result, TailRecArguments): ...
. 我想缓存值,而不是尾递归调用的参数。
(实际上,当递归调用返回实际值时,我认为您可以将所有内容临时存储TailRecArguments
在一个列表中,并在缓存中添加与该列表大小一样多的条目。这会使解决方案复杂化,但如果您仍然可以接受有性能问题。这可能会在相互递归的情况下引发一些错误,如果需要,我将继续处理)。
测试
这些是我用来测试装饰器的一些基本功能:
@tail_recursive
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
"""
return True if n == 0 else odd(n - 1)
@tail_recursive
def odd(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> odd(100)
False
>>> odd(101)
True
"""
return False if n == 0 else even(n - 1)
@tail_recursive
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(30)
265252859812191058636308480000000
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive
def fib(n, a = 0, b = 1):
"""
>>> import sys
>>> sys.setrecursionlimit(20)
>>> fib(30)
832040
"""
return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)
if __name__ == '__main__':
import doctest
doctest.testmod()
请注意,缓存在这些示例中不是很有用,以阶乘为例:fact(10)
is never going to use fact(8)
, in fact
fact(8) |
fact(10) |
|
事实(10, 1) |
|
事实(9, 10) |
事实(8, 1) |
事实(8, 90) |
... |
... |
累加器是缓存键的一部分,因此您应该通过自定义要缓存的参数来更改缓存策略(同样,如果需要,我也可以为此提出解决方案)。
更新 - 缓存优化
这是对原始答案中使用的缓存策略的部分修复。主要问题是考虑到通用尾递归算法的工作原理(参见阶乘示例),在缓存键中包含所有参数效率低下。
第一个可能的优化是让用户选择哪些参数用于键,哪些参数用于值。由于类型提示,它的可读性要低得多,但是测试使一切变得更加清晰:
class Logger:
def __init__(self, name):
self._name = name
self._entries = []
def log(self, s):
self._entries.append(s)
def print(self):
log_prefix = f"[{self._name}] - "
print(log_prefix + f"\n{log_prefix}".join(self._entries))
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
default_logger = Logger('default')
def tail_recursive(logger: Logger = default_logger, \
get_cache_key: Callable[[Iterable, Dict], Hashable] = lambda args, kwargs: \
functools._make_key(args, kwargs, False),\
get_result_after_cache_hit: Callable[[Any, Iterable, Dict], Any] = lambda value, args, kwargs: \
value):
def decorator(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
f._initial_key = get_cache_key(f._new_args, f._new_kwargs)
while True:
cache_key = get_cache_key(f._new_args, f._new_kwargs)
if cache_key in f._cache:
logger.log('cache hit for ' + str(cache_key))
return get_result_after_cache_hit(f._cache[cache_key], f._new_args, f._new_kwargs)
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[f._initial_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper
return decorator
除了Logger
仅用于确认缓存命中的类之外,主要区别在于每个函数现在都有一个名为 的新成员_initial_key
,它存储第一次调用的键。这样,如果我调用fact(5)
,5
就变成了_initial_key
并且结果被放入f._cache[5]
。
这可以优化相互递归和尾递归函数,但在某些情况下无效。让我们从最好的情况开始:
fact_logger = Logger('fact')
@tail_recursive(logger=fact_logger, get_cache_key=lambda args, kwargs: args[0],\
get_result_after_cache_hit=lambda value, args, kwargs: value * args[1])
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(5)
120
>>> fact(30)
265252859812191058636308480000000
>>> fact_logger.print()
[fact] - cache hit for 5
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive
装饰器初始化包括(记录器),它get_cache_key
指定只有第一个参数n
应该是缓存键的一部分,并get_result_after_cache_hit
指定在缓存命中后如何产生最终结果。在上述情况下,当fact(30)
达到时fact(5, <partial_factorial>)
,结果立即计算为<partial_factorial> * f._cache[5]
。
也是如此even-odd
,除了在这种情况下,默认参数tail_recursive
绰绰有余:
even_logger = Logger('even')
@tail_recursive(logger=even_logger)
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
>>> even(104)
True
>>> even_logger.print()
[even] - cache hit for 100
"""
return True if n == 0 else odd(n - 1)
不幸的是,这不适用于例如斐波那契函数。您应该通过在每次调用期间打印参数来轻松地说服自己,结果如下:
30 0 1
29 1 1
28 1 2
27 2 3
26 3 5
25 5 8
...
建立缓存键规则需要一个更复杂的逻辑,这可能会使tail_recursive
装饰器变得非常不可读且可移植性较差。