我有一个整数数组self.N
,我正在尝试写self.N[node] +=1
,但是每当我写self.N[node]
它时,它都会给我一个值错误,因为它有多个元素,它不能。
def __init__(self, exploration_weight=1):
self.Q = defaultdict(int) # total reward of each node
self.N = defaultdict(int) # total visit count for each node
self.children = dict() # children of each node
def do_rollout(self, node, player):
"Make the tree one layer better. (Train for one iteration.)"
print("in rollout")
"Find an unexplored descendent of `node`"
path = self._select(node)
leaf = path[-1]
"Update the `children` dict with the children of `node`"
if leaf not in self.children:
self.children[node] = node.find_children()
"Send the reward back up to the ancestors of the leaf"
for node in reversed(path):
self.N[node] += 1
def _select(self, node):
"Find an unexplored descendent of `node`"
path = []
while True:
path.append(node)
if node not in self.children or not self.children[node]:
# node is either unexplored or terminal
return path
unexplored = self.children[node] - self.children.keys()
if unexplored:
n = unexplored.pop()
path.append(n)
return path
node = self.children[node] # descend a layer deeper
我希望这self.N[node] += 1
只会增加self.N
at index的值node
我在调试器中发现键似乎有问题,但我不知道是什么。
(Pdb) self.children[leaf]
turn:2
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | X | O |
_ _ _ _ _ _ _
0 1 2 3 4 5 6
*** KeyError: turn: 2, done False, winner: None
并且节点比较器似乎也不起作用,但我不知道为什么:
我还发现,就在它引发错误之前,调试器说:
(pdb) p self.children.keys()
*** TypeError: 'Node' object is not iterable
尽管它显然一直有效到这一点
Traceback (most recent call last):
File "test_MCTS.py", line 52, in <module>
agent_wins += play_bot()
File "test_MCTS.py", line 18, in play_bot
tree.do_rollout(board, 0) # player 0 is 2nd
File "/Users/TorSaxberg/.../MCTS_minimal.py", line 50, in do_rollout
self.N[node] += 1
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
board 是一个包含 [6,7] 数组和填充它的方法的节点
一个最小的例子:
from random import randint
from collections import defaultdict
Q = defaultdict(int) # total reward of each node
N = defaultdict(int) # total visit count for each node
children = dict() # children of each node
def do_rollout(num):
"Make the tree one layer better. (Train for one iteration.)"
print("in rollout")
"Find an unexplored descendent of `node`"
path = _select(num)
leaf = path[-1]
"Update the `children` dict with the children of `node`"
if leaf not in children: # a dict()
children[num] = randint(0,5)
"Send the reward back up to the ancestors of the leaf"
for num in reversed(path):
N[num] += 1 # a dict()
def _select(num):
"Find an unexplored descendent of `node`"
path = []
while True:
path.append(num)
if num not in children or not children[num]:
return path
breakpoint()
unexplored = children[num] - children.keys() # a set()
if unexplored:
n = unexplored.pop()
path.append(n)
return path
# descend a layer deeper
num = children[randint(0,5)]
num = randint(0,5)
for _ in range(10):
do_rollout(num)
但我无法通过另一个 TyepError 来显示上面的错误
Traceback (most recent call last):
File "test_ValueError.py", line 43, in <module>
do_rollout(num)
File "test_ValueError.py", line 14, in do_rollout
path = _select(num)
File "test_ValueError.py", line 33, in _select
unexplored = children[num] - children.keys() # a set()
TypeError: 'int' object is not iterable
这很奇怪,因为节点也不是可迭代的(来自调试)
TypeError: 'Node' object is not iterable