我的 IS-MCTS 实现总是选择 allin,我不知道为什么。也许你们可以帮助我?
我已经尝试将节点中保存的值从wins更改为value,这意味着获得的筹码数量,但也得到了不好的结果。该算法甚至输给了一个随机玩家并且只跟注玩家。
mcts 方法有什么问题吗?如果不是,它可能是 ucb1 方法或“节点”类。
import math
import random
def find_best_action(self, rootstate):
    """Run IS-MCTS from ``rootstate`` and return the most-visited root action.

    Each of ``self.max_iterations`` iterations: sample a determinization of
    the hidden information, descend the tree (selection), add exactly one
    new node (expansion), play a uniformly random rollout to a terminal
    state, and backpropagate the terminal result up the visited path.

    Parameters:
        rootstate: game-state object providing ``reset``, ``randomize``,
            ``is_terminal``, ``get_valid_actions``, ``act``,
            ``get_current_player`` and ``get_result``.

    Returns:
        The action of the most-visited root child (robust-child criterion).
    """
    rootnode = Node()
    for _ in range(self.max_iterations):
        node = rootnode
        # Determinization: sample a concrete world consistent with the
        # information set of the player to act.
        state = rootstate.reset()
        state.randomize(state.get_current_player())

        # Selection / Expansion
        while not state.is_terminal():
            valid_actions = state.get_valid_actions()
            untried_actions = node.get_untried_actions(valid_actions)
            if untried_actions:
                # Expansion: add ONE child, then leave the tree policy.
                # BUG FIX: the original kept expanding until the state was
                # terminal, so the rollout loop below was dead code and a
                # full path was added to the tree on every iteration.
                a = random.choice(untried_actions)
                node = node.add_child(a, state.get_current_player())
                state.act(a)
                break
            # Selection. NOTE(review): the tree is shared across
            # determinizations, so a child created under one determinization
            # may hold an action that is illegal in the current one —
            # restrict selection to compatible children (per Cowling et
            # al.'s ISMCTS; their full algorithm also uses availability
            # counts in place of parent visits).
            compatible = [c for c in node.child_nodes if c.action in valid_actions]
            node = max(compatible, key=lambda c: c.calc_ucb1_score(0.7))
            state.act(node.action)

        # Rollout: finish the game with uniformly random actions.
        while not state.is_terminal():
            state.act(random.choice(state.get_valid_actions()))

        # Backpropagation: credit the terminal result along the path
        # from the expanded node back to the root.
        while node is not None:
            node.update(state)
            node = node.parent_node

    # Return the most-visited root action.
    return max(rootnode.child_nodes, key=lambda n: n.visits).action
我想它必须是节点类中的某些东西,它选择了错误的孩子。
import math
import random
class Node:
    """A node in the IS-MCTS game tree.

    ``wins`` accumulates the terminal result (e.g. chips won) from the
    point of view of ``acted_player`` — the player whose ``action`` led
    to this node — so UCB1 maximizes each player's own expected payoff.
    """

    def __init__(self, action=None, parent=None, acted_player=None):
        self.action = action              # action that led to this node (None at root)
        self.parent_node = parent         # None at root
        self.child_nodes = []
        self.wins = 0                     # cumulative result credited to acted_player
        self.visits = 0
        self.acted_player = acted_player  # player who performed `action` (None at root)

    def get_untried_actions(self, valid_actions):
        """Return the subset of ``valid_actions`` that has no child node yet."""
        tried_actions = [child.action for child in self.child_nodes]
        return [action for action in valid_actions if action not in tried_actions]

    def select_child(self, exploration=0.7):
        """Return the child with the highest UCB1 score."""
        return max(self.child_nodes,
                   key=lambda child: child.calc_ucb1_score(exploration))

    def add_child(self, a, p):
        """Create, attach and return a child reached by action ``a`` taken by player ``p``."""
        n = Node(action=a, parent=self, acted_player=p)
        self.child_nodes.append(n)
        return n

    def update(self, terminal_state):
        """Backpropagation step: count the visit and credit the result.

        The root has ``acted_player is None`` and accumulates no reward.
        """
        self.visits += 1
        if self.acted_player is not None:
            self.wins += terminal_state.get_result(self.acted_player)

    def calc_ucb1_score(self, exploration):
        """Return the UCB1 score; unvisited children score +inf.

        BUG FIX: the original called bare ``sqrt``/``log`` (NameError —
        only ``import math`` is in scope) and returned 0 for unvisited
        children, which made selection systematically prefer already
        visited siblings over fresh ones.

        NOTE(review): if ``wins`` holds raw chip amounts, normalize the
        results to roughly [0, 1] (or scale ``exploration`` to the stack
        size) — otherwise the exploitation term dwarfs the exploration
        term and the search locks onto a single high-variance action
        such as all-in. TODO confirm against the game's result scale.
        """
        if self.visits == 0:
            return float("inf")
        return self._calc_avg_wins() + exploration * math.sqrt(
            2.0 * math.log(self.parent_node.visits) / self.visits)

    def _calc_avg_wins(self):
        """Mean result per visit; 0.0 before any visit."""
        if self.visits == 0:
            return 0.0
        return float(self.wins) / float(self.visits)