
This is my DQN implementation. I went through the code in many other people's repositories but couldn't find any differences from mine. I suspect there is a bug somewhere in learn(), but I can't spot it. I also compared against the DQN code on the official PyTorch website and checked the actions, rewards, states, and next states; they all look fine to me.
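
For reference, the quantity learn() is supposed to regress Q(s, a) onto is the standard DQN target. A minimal standalone sketch of that target (dqn_target is only an illustration, not code from my project):

import torch

def dqn_target(reward, next_state, done, target_net, gamma):
    # y = r                                   if the episode terminated at this step
    # y = r + gamma * max_a Q_target(s', a)   otherwise
    if done:
        return reward
    with torch.no_grad():
        return reward + gamma * target_net(next_state).max(dim=-1).values.item()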

Here is my agent class code.

import copy
import math
import random

import torch
import torch.nn as nn

class DQNAgent():
    def __init__(self, net, capacity, n_actions, eps_start, eps_end, eps_decay, batch_size, gamma, lr):
        self.net = net
        self.target_net = copy.deepcopy(self.net)
        self.buffer =  ReplayBuffer(capacity)
        self.n_actions = n_actions
        self.device = next(net.parameters()).device
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.batch_size = batch_size
        self.gamma = gamma
        self.sample_step = 0   # for decaying epsilon-greedy policy
        self.loss_fn = nn.SmoothL1Loss()
        self.optim = torch.optim.Adam(self.net.parameters(), lr=lr)

    def store_transition(self, state, action, next_state, reward):
        self.buffer.push(state, action, next_state, reward)

    def select_action(self, state):
        '''
        Returns:
        - action: shape [bs, ]
        '''

        # Decaying epsilon
        eps = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-1 * self.sample_step / self.eps_decay)
        self.sample_step += 1

        if random.random() > eps:  # greedy
            self.net.eval()
            with torch.no_grad():
                return self.net(state).argmax(-1).item()
        else:  # random
            return torch.randint(self.n_actions, (1,)).item()

    def learn(self):
        self.net.train()
        self.target_net.eval()

        transitions = self.buffer.sample(self.batch_size)
        batch = Transition(*zip(*transitions))
        
        non_final_mask = torch.tensor(tuple(map(lambda x: x is not None, batch.next_state)), device=self.device, dtype=torch.bool)
        non_final_next_states = torch.tensor([state for state in batch.next_state if state is not None], dtype=torch.float32, device=self.device)

        states = torch.stack(batch.state).to(self.device)  # [bs, input_dim]
        rewards = torch.tensor(batch.reward).to(self.device)  # [bs, ]
        actions = torch.tensor(batch.action).unsqueeze(-1).to(self.device)   # [bs, ] --> [bs, 1]

        Q = self.net(states).gather(1, actions)  # [bs, n_actions] --> [bs, 1]

        next_state_values = torch.zeros(self.batch_size, device=self.device)
        next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
        expected_state_action_values = (next_state_values * self.gamma) + rewards

        self.optim.zero_grad()
        loss = self.loss_fn(Q, expected_state_action_values.unsqueeze(1))
        loss.backward()
        self.optim.step()
        return loss.detach().item()

    def update_target_net(self):
        self.target_net.load_state_dict(self.net.state_dict())
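
ReplayBuffer and Transition are not shown above; they are the usual minimal versions along the lines of the PyTorch DQN tutorial. A sketch of what the agent expects (it may differ in detail from my actual file):

from collections import deque, namedtuple
import random

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayBuffer:
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        # stores a Transition(state, action, next_state, reward)
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)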

Here is my training loop code.

reward_history = []
for episode_i in tqdm(range(max_episodes)):
    s = env.reset()
    # early stopping once the last 5 episodes are near the reward ceiling
    if len(reward_history) >= 5 and np.mean(reward_history[-5:]) >= max_steps * 0.9:
        break
    ep_reward = 0
    while True:
        s = torch.tensor(s, dtype=torch.float32, device=device)
        a = agent.select_action(s)
        new_s, r, done, info = env.step(a)

        agent.store_transition(s, a, new_s, r)

        if done:
            reward_history.append(ep_reward)
            break

        s = new_s
        ep_reward += 1

        if len(agent.buffer) >= agent.batch_size:
            loss = agent.learn()
    
    print(f'Episode {episode_i} | Reward: {ep_reward}')

    if episode_i % target_update_intv == 0:
        agent.update_target_net()
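
For completeness, the setup this loop assumes looks roughly like this (the environment and the hyperparameter values here are placeholders, not necessarily my exact ones):

import gym
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

env = gym.make('CartPole-v1')          # old gym API: reset() returns obs, step() returns 4 values
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n

net = nn.Sequential(
    nn.Linear(n_states, 128), nn.ReLU(),
    nn.Linear(128, n_actions),
).to(device)

agent = DQNAgent(net, capacity=10_000, n_actions=n_actions,
                 eps_start=0.9, eps_end=0.05, eps_decay=1000,
                 batch_size=128, gamma=0.99, lr=1e-4)

max_episodes = 600
max_steps = 500            # reward ceiling used in the early-stopping check
target_update_intv = 10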