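I have a collection of integer counts, self.N (a defaultdict(int)), and I'm trying to write self.N[node] += 1. Whenever I access self.N[node], it raises a ValueError saying that the truth value of an array with more than one element is ambiguous.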


def __init__(self, exploration_weight=1):
    self.Q = defaultdict(int)  # total reward of each node
    self.N = defaultdict(int)  # total visit count for each node
    self.children = dict()  # children of each node


def do_rollout(self, node, player):
    "Make the tree one layer better. (Train for one iteration.)"
    print("in rollout")

    "Find an unexplored descendent of `node`"
    path = self._select(node)
    leaf = path[-1]

    "Update the `children` dict with the children of `node`"
    if leaf not in self.children:
        self.children[node] = node.find_children()

    "Send the reward back up to the ancestors of the leaf"
    for node in reversed(path):
        self.N[node] += 1


def _select(self, node):
    "Find an unexplored descendent of `node`"
    path = []
    while True:
        path.append(node)
        if node not in self.children or not self.children[node]:
            # node is either unexplored or terminal
            return path
        unexplored = self.children[node] - self.children.keys()
        if unexplored:
            n = unexplored.pop()
            path.append(n)
            return path
        node = self.children[node]  # descend a layer deeper

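I expected self.N[node] += 1 to increment only the value stored in self.N under the key node.

In the debugger I found that something seems to be wrong with the keys, but I don't know what.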


    (Pdb) self.children[leaf]

    turn:2
        |   |   |   |   |   |   |   |
        |   |   |   |   |   |   |   |
        |   |   |   |   |   |   |   |
        |   |   |   |   |   |   |   |
        |   |   |   |   |   |   |   |
        |   |   |   |   |   | X | O |
          _   _   _   _   _   _   _
          0   1   2   3   4   5   6
    *** KeyError: turn: 2, done False, winner: None
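One thing that might explain the KeyError: do_rollout checks whether leaf is in self.children, but then stores the children under node. Whenever the two differ, leaf never becomes a key. A minimal sketch of that pattern, assuming plain strings as stand-ins for Node objects (the names `expand`, `"root"`, `"leaf"`, and the child labels are hypothetical):

```python
children = {}


def expand(node, leaf):
    # Mirrors the check in do_rollout: test `leaf`, but store under `node`.
    if leaf not in children:
        children[node] = {"child_a", "child_b"}  # hypothetical children


expand("root", "leaf")

stored_under_node = "root" in children  # the key that was written
stored_under_leaf = "leaf" in children  # the key that was checked

try:
    children["leaf"]
    lookup_error = None
except KeyError as exc:
    # Same kind of failure as `self.children[leaf]` in the pdb session.
    lookup_error = exc
```

If this is the cause, storing the result under leaf instead of node should make the lookup succeed.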

The node comparison doesn't seem to work either, and I don't know why.

I also found that, just before the error is raised, the debugger reports:


    (pdb) p self.children.keys()
    *** TypeError: 'Node' object is not iterable

even though everything clearly works up to that point.


Traceback (most recent call last):
  File "test_MCTS.py", line 52, in <module>
    agent_wins += play_bot()
  File "test_MCTS.py", line 18, in play_bot
    tree.do_rollout(board, 0) # player 0 is 2nd
  File "/Users/TorSaxberg/.../MCTS_minimal.py", line 50, in do_rollout
    self.N[node] += 1
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

board is a node that holds an array of shape [6, 7] and the methods to fill it.

A minimal example:
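The Node class itself is not shown, so the following is only a sketch of one way this ValueError could come out of a dictionary access. It assumes that Node.__eq__ compares the board arrays with ==, which for NumPy arrays yields an elementwise array rather than a single bool, and that two different boards can share a hash. To keep the sketch free of dependencies, `AmbiguousTruth` is a hypothetical stand-in that raises on bool() the way a multi-element NumPy array does; `Node`, `cells`, and the constant hash are likewise illustrative.

```python
from collections import defaultdict


class AmbiguousTruth:
    """Hypothetical stand-in for the array that `array_a == array_b` returns."""

    def __bool__(self):
        # A multi-element NumPy array raises this kind of error on bool().
        raise ValueError(
            "The truth value of an array with more than one element "
            "is ambiguous. Use a.any() or a.all()"
        )


class Node:
    """Hypothetical node; the real class may differ."""

    def __init__(self, cells):
        self.cells = cells

    def __hash__(self):
        # Assumption: a coarse hash, so distinct boards can collide.
        return 0

    def __eq__(self, other):
        # Assumption: an elementwise comparison instead of a single bool.
        return AmbiguousTruth()


N = defaultdict(int)
first, second = Node([0, 0]), Node([0, 1])

N[first] += 1  # no other key with this hash yet, so __eq__ is never called

try:
    N[second] += 1  # same hash, different object: the dict calls __eq__
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

If the real __eq__ behaves like this, returning a single bool (for example with np.array_equal) would avoid the error.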


    from random import randint
    from collections import defaultdict

    Q = defaultdict(int)  # total reward of each node
    N = defaultdict(int)  # total visit count for each node
    children = dict()  # children of each node


    def do_rollout(num):
        "Make the tree one layer better. (Train for one iteration.)"
        print("in rollout")

        "Find an unexplored descendent of `node`"
        path = _select(num)
        leaf = path[-1]

        "Update the `children` dict with the children of `node`"
        if leaf not in children: # a dict()
            children[num] = randint(0,5)

        "Send the reward back up to the ancestors of the leaf"
        for num in reversed(path):
            N[num] += 1 # a dict()

    def _select(num):
        "Find an unexplored descendent of `node`"
        path = []
        while True:
            path.append(num)
            if num not in children or not children[num]:
                return path
            breakpoint()
            unexplored = children[num] - children.keys() # a set()
            if unexplored:
                n = unexplored.pop()
                path.append(n)
                return path
            # descend a layer deeper
            num = children[randint(0,5)]

    num = randint(0,5)
    for _ in range(10):
        do_rollout(num)

However, I can't reproduce the error above with it, because it fails earlier with a different TypeError:


    Traceback (most recent call last):
      File "test_ValueError.py", line 43, in <module>
        do_rollout(num)
      File "test_ValueError.py", line 14, in do_rollout
        path = _select(num)
      File "test_ValueError.py", line 33, in _select
        unexplored = children[num] - children.keys() # a set()
    TypeError: 'int' object is not iterable

This is strange, because a node is not iterable either (from the debugger): TypeError: 'Node' object is not iterable
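A small sketch of why the subtraction fails in the minimal example. When the left operand is not a set, Python falls back to the subtraction defined by the keys view, which tries to build a set from the left operand and therefore needs it to be iterable. The second half assumes, hypothetically, that each value in children is a set of child nodes (`children_as_sets` and its contents are illustrative).

```python
children = {0: 3}  # the value is a plain int, as in the minimal example

try:
    children[0] - children.keys()
    int_error = None
except TypeError as exc:
    int_error = str(exc)  # the int cannot be turned into a set

# Hypothetical variant: every value is a set of child nodes.
children_as_sets = {0: {1, 2, 3}, 1: set()}

# Children of node 0 that are not yet keys of the dict.
unexplored = children_as_sets[0] - children_as_sets.keys()
```

The same reasoning would apply to a single Node stored as the value, which would match the 'Node' object is not iterable message from the debugger.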
