1

当使用像 lru_cache 这样的 functools 缓存函数时,内部函数不会更新非局部变量的值。同样的方法在没有装饰器的情况下也有效。

使用缓存装饰器时非局部变量是否没有更新?另外,如果我必须更新非局部变量但还要存储结果以避免重复工作,该怎么办?或者我是否需要从缓存函数中返回答案?

例如。以下未正确更新非局部变量的值

def foo(x):
    outer_var=0

    @lru_cache
    def bar(i):
        nonlocal outer_var
        if condition:
            outer_var+=1
        else:
            bar(i+1)

    bar(x)
    return outer_var

背景

我正在尝试解码方式问题,该问题正在寻找可以将一串数字解释为字母的方式数量。我从第一个字母开始,采取一两个步骤并检查它们是否有效。到达字符串末尾时,我更新了一个非局部变量,该变量存储了可能的方式数。此方法在不使用 lru_cache 的情况下给出正确答案,但在使用缓存时失败。我返回值的另一种方法正在工作,但我想检查如何在使用记忆装饰器时更新非局部变量。

我的错误代码:

ways=0
@lru_cache(None) # works well without this
def recurse(i):
    nonlocal ways
    if i==len(s):
        ways+=1
    elif i<len(s):
        if 1<=int(s[i])<=9:
            recurse(i+1)
        if i+2<=len(s) and 10<=int(s[i:i+2])<=26:
            recurse(i+2)
    return 

recurse(0)
return ways

公认的解决方案:

@lru_cache(None)
def recurse(i):
    if i==len(s):
        return 1

    elif i<len(s):
        ans=0
        if 1<=int(s[i])<=9:
            ans+= recurse(i+1)
        if i+2<=len(s) and 10<=int(s[i:i+2])<=26:
            ans+= recurse(i+2)
        return ans

return recurse(0)
4

1 回答 1

1

本身没有什么特别之处lru_cache,一个变量或递归会导致这里的任何固有问题。nonlocal这个问题纯粹是逻辑上的,而不是行为异常。看这个最小的例子:

from functools import lru_cache

def foo():
    c = 0

    @lru_cache(None)
    def bar(i=0):
        nonlocal c

        if i < 5:
            c += 1
            bar(i + 1)

    bar()
    return c

print(foo()) # => 5

解码方式代码的缓存版本中的问题是由于递归调用的重叠性质。缓存可防止 base case 调用recurse(i)wherei == len(s)多次执行,即使它是从不同的递归路径到达的。

建立这一点的一个好方法是print("hello")在基本案例(if i == len(s)分支)中打 a,然后给它一个相当大的问题。你会看到print("hello")一次,而且只有一次,并且由于除了通过whenways之外无法通过任何其他方式更新,所以当一切都说完了,你就剩下了。recurse(i)i == len(s)ways == 1

在上面的玩具示例中,只有一个递归路径:调用i在 0 到 9 之间扩展,并且从不使用缓存。相比之下,解码方式提供了多个递归路径,因此路径通过recurse(i+1)线性找到基本情况,然后随着堆栈展开,recurse(i+2)尝试找到其他方式来达到它。

添加缓存会切断额外的路径,但对于每个中间节点没有返回值。使用缓存,就像您有一个子问题的记忆或动态编程表,但您从不更新任何条目,因此整个表为零(基本情况除外)。

这是缓存导致的线性行为的示例:

from functools import lru_cache

def cached():
    @lru_cache(None)
    def cached_recurse(i=0):
        print("cached", i)

        if i < 3:
            cached_recurse(i + 1)
            cached_recurse(i + 2)

    cached_recurse()

def uncached():
    def uncached_recurse(i=0):
        print("uncached", i)

        if i < 3:
            uncached_recurse(i + 1)
            uncached_recurse(i + 2)

    uncached_recurse()

cached()
uncached()

输出:

cached 0
cached 1
cached 2
cached 3
cached 4
uncached 0
uncached 1
uncached 2
uncached 3
uncached 4
uncached 3
uncached 2
uncached 3
uncached 4

解决方案与您展示的完全一样:将结果向上传递并使用缓存来存储代表子问题的每个节点的值。这是两全其美的:我们有子问题的值,但没有重新执行最终导致您的ways += 1基本情况的函数。

换句话说,如果您要使用缓存,请将其视为查找表,而不仅仅是调用树修剪器。在您的尝试中,它不记得做了什么工作,只是阻止它再次完成。

于 2021-08-28T19:27:11.420 回答