1

我在 Python 中实现了一个树搜索(下面只是其中一种实现方式),但运行速度很慢。怎样才能让它运行得更快?我了解到有 numba,但不明白它是如何工作的,也不清楚它支持什么、不支持什么。有人用 numba 加速过树搜索吗?提前感谢!

    def search(self, canonicalBoard, game_copy, rootplayer_index, max_in_prediction=False, callback_counter=0):
    # One recursive step of a tree search over game states. The Qsa/Nsa/Ns/Ps tables
    # and the policy/value networks suggest an AlphaZero-style MCTS (confirm).
    #
    # Visible flow:
    #   1. Terminal check: the getGameEnded result for state `s` is cached in
    #      self.Es, and a nonzero cached value is returned immediately.
    #   2. Expansion (callback_counter == 0 and `s` not yet in self.Ps): store the
    #      policy prior masked to legal moves, cache the legal-move vector and a
    #      zero visit count, then return the value-network estimate.
    #   3. callback_counter > 0: return the value-network estimate without
    #      expanding `s` or touching any table.
    #   4. Otherwise: an action `a` is selected and a value `v` obtained. That code
    #      is elided in this snippet. Qsa/Nsa are then updated and Ns[s] incremented.
    #
    # Args:
    #   canonicalBoard: passed to game.stringRepresentation (the cache key) and to
    #       both networks' predict().
    #   game_copy: game object used for the end-of-game check and, via
    #       game_copy.roundo, for the legal moves.
    #   rootplayer_index: index into game_copy.roundo.players; also passed to
    #       getGameEnded.
    #   max_in_prediction: not read in the visible lines (may be used in the
    #       elided section).
    #   callback_counter: 0 allows expansion; > 0 short-circuits to a plain
    #       value-network estimate.
    # Returns:
    #   The cached terminal result, or a value estimate `v`.
    #
    # NOTE(review), on speed: visible per-call work is one or two network predict()
    # calls plus dict lookups keyed by the state string. Profile first (cProfile).
    # numba's nopython mode targets numeric code over NumPy arrays and a subset of
    # Python. It likely cannot compile calls into these network/game objects or
    # these dict-of-arbitrary-object caches, so it is unlikely to help here as written.
    start_time = time.time()  # NOTE(review): not read in the visible lines
    counter = 0  # NOTE(review): not read in the visible lines

    # Key for all of the per-state caches below (Es, Ps, Vs, Ns, Qsa, Nsa).
    s = self.game.stringRepresentation(canonicalBoard)

    if s not in self.Es:
        # Cache the end-of-game result for this state. The original author noted
        # that the second argument is "probably" a player index and that
        # getGameEnded returns a reward.
        self.Es[s] = self.game.getGameEnded(game_copy, rootplayer_index)
    if self.Es[s]!=0:
        # Terminal state: return the cached result directly.
        return self.Es[s]

    if callback_counter == 0 and s not in self.Ps:
        # Leaf expansion: first visit to `s` at callback depth 0.
        raw_pol = self.policy_net.predict(canonicalBoard)
        # [0][0] is the first row, first column of the prediction. Presumably this
        # is a batch of one with a scalar value output; confirm against value_net.
        v = self.value_net.predict(canonicalBoard)[0][0]
        self.Ps[s] = raw_pol[0]  # prior from the first row of the policy output
        valids = game_copy.roundo.get_legal_moves(game_copy.roundo.players[rootplayer_index])
        index_valids = [card.index for card in valids]  # legal moves as action ids
        mask = np.array(index_valids)
        # Boolean mask over action ids 0..35. The 36 is hard-coded; it is presumably
        # the number of cards (confirm).
        # NOTE(review): `y in mask` scans the whole array for every id. A set built
        # once from index_valids would make each membership test O(1).
        mask = np.array([1 if y in mask else 0 for y in range(36)], dtype=bool)
        self.Ps[s][~mask] = 0  # zero the prior of illegal actions
        # Renormalize the prior over legal actions only.
        # NOTE(review): if every legal action had a zero prior, or there are no
        # legal moves, the sum is 0 and this divides by zero.
        self.Ps[s] /= sum(self.Ps[s])
        # 0/1 legality vector over the same 36 action ids, cached for selection.
        valids = [1 if y in index_valids else 0 for y in range(36)]
        self.Vs[s] = valids
        self.Ns[s] = 0  # visit count of the newly expanded state
        return v  # the leaf's value estimate is what propagates back up

    if callback_counter > 0:
        # Nested call: return the value-network estimate without expanding `s`.
        v = self.value_net.predict(canonicalBoard)[0][0]
        return v



    # Selection on an already-expanded state.
    valids = self.Vs[s]  # cached 0/1 legality vector from expansion
    cur_best = -float('inf')  # presumably the best selection score seen so far
    best_act = -1  # sentinel: no action chosen yet

  ....

    # Elided above: the action choice that defines `a`, the step that produces `v`
    # (presumably a recursive search call), and the `if` that this `else` pairs with.
    else:
        # Presumably the branch for an (s, a) pair not yet in Qsa: initialise its
        # value estimate and its visit count.
        self.Qsa[(s,a)] = v
        self.Nsa[(s,a)] = 1

    self.Ns[s] += 1  # record one more visit to state `s`
    return v
4

0 回答 0