我用 Python 实现了一个树搜索,但目前这种写法运行得很慢。我怎样才能让它运行得更快?我听说过 numba,但不太理解它的工作原理,也不清楚它支持什么、不支持什么。有人用 numba 加速过树搜索吗?提前谢谢大家!
# One search step over game states that are cached by the board's string key.
# Returns a value for the given state: the cached game-end result when it is
# non-zero, otherwise a value-network estimate (or the value `v` produced by
# the elided middle section of the method).
#
# Parameters, as used in the visible lines:
#   canonicalBoard    - passed to game.stringRepresentation() and to both
#                       networks' predict().
#   game_copy         - passed to game.getGameEnded(); also supplies
#                       roundo.players and roundo.get_legal_moves().
#   rootplayer_index  - index into game_copy.roundo.players; also forwarded
#                       to getGameEnded().
#   max_in_prediction - not referenced in the visible lines.
#   callback_counter  - 0 allows expansion of an unseen state; a value > 0
#                       makes the method return a value-net estimate.
#
# NOTE(review): the paste lost its indentation, so the nesting of the
# statements below is inferred from reading order only - confirm against the
# real file.
def search(self, canonicalBoard, game_copy, rootplayer_index, max_in_prediction=False, callback_counter=0):
# NOTE(review): start_time and counter are not read in the visible lines;
# they may be used in the elided section - confirm before removing.
start_time = time.time()
counter = 0
# s is the key used for every per-state cache (Es, Ps, Vs, Ns, Qsa, Nsa).
s = self.game.stringRepresentation(canonicalBoard)
# Cache the game-end result per state so getGameEnded runs once per key.
if s not in self.Es:
self.Es[s] = self.game.getGameEnded(game_copy, rootplayer_index) # 1 -> probably player index # getGameEnded gives reward
if self.Es[s]!=0:
# terminal
return self.Es[s]
# First visit of this state at callback depth 0: query both networks,
# store the masked policy and the valid-move list, then return the value.
if callback_counter == 0 and s not in self.Ps:
raw_pol = self.policy_net.predict(canonicalBoard)
v = self.value_net.predict(canonicalBoard)[0][0]
# No copy is taken here, so the in-place edits to self.Ps[s] below also
# modify raw_pol[0] (assuming predict() returns a NumPy array - TODO confirm).
self.Ps[s] = raw_pol[0]
valids = game_copy.roundo.get_legal_moves(game_copy.roundo.players[rootplayer_index])
index_valids = [card.index for card in valids]
mask = np.array(index_valids)
# Boolean mask of length 36, True where the index is a legal move.
# NOTE(review): `y in mask` on a NumPy array scans the whole array for each
# of the 36 values; 36 is presumably the number of card indices - confirm.
mask = np.array([1 if y in mask else 0 for y in range(36)], dtype=bool)
# Zero the policy entries of illegal moves, then renormalise the rest.
self.Ps[s][~mask] = 0
# NOTE(review): built-in sum() iterates element by element in Python; if all
# legal entries are 0 the divisor is 0 - confirm that cannot happen.
self.Ps[s] /= sum(self.Ps[s])
# Same membership test as the mask above, stored as a 0/1 list in Vs.
valids = [1 if y in index_valids else 0 for y in range(36)] # check it out
self.Vs[s] = valids
self.Ns[s] = 0
return v # !!
# For callback_counter > 0 the value network's estimate is returned directly.
if callback_counter > 0:
v = self.value_net.predict(canonicalBoard)[0][0]
return v
valids = self.Vs[s]
# Initial values for the best-action search; the search itself is in the
# elided section below.
cur_best = -float('inf')
best_act = -1
# The original post elides the rest of the method here; `a`, `v` and the
# `if` that the following `else` belongs to are defined in that part.
....
else:
# First recorded visit of the (state, action) pair: seed Q with v and set
# the visit count to 1.
self.Qsa[(s,a)] = v
self.Nsa[(s,a)] = 1
# Increment the visit count of the state itself and return the value.
self.Ns[s] += 1
return v