This is the code for my DQN implementation. I have compared it against many other people's repositories but cannot find any difference. I suspect there is a bug somewhere in learn(), but I cannot spot it. I also went through the DQN code on the official PyTorch website and checked their actions, rewards, states, and next states against mine, and everything looks fine.
Here is my agent class code.
import copy
import math
import random

import torch
import torch.nn as nn

# Transition and ReplayBuffer are defined elsewhere (a sketch follows the class).

class DQNAgent():
    def __init__(self, net, capacity, n_actions, eps_start, eps_end, eps_decay, batch_size, gamma, lr):
        self.net = net
        self.target_net = copy.deepcopy(self.net)  # target network, synced periodically
        self.buffer = ReplayBuffer(capacity)
        self.n_actions = n_actions
        self.device = next(net.parameters()).device
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.batch_size = batch_size
        self.gamma = gamma
        self.sample_step = 0  # step counter for the decaying epsilon-greedy policy
        self.loss_fn = nn.SmoothL1Loss()
        self.optim = torch.optim.Adam(self.net.parameters(), lr=lr)

    def store_transition(self, state, action, next_state, reward):
        self.buffer.push(state, action, next_state, reward)

    def select_action(self, state):
        '''
        Returns:
            - action: a Python int in [0, n_actions), chosen epsilon-greedily
        '''
        # Exponentially decaying epsilon
        eps = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-1 * self.sample_step / self.eps_decay)
        self.sample_step += 1
        if random.random() > eps:  # greedy action
            self.net.eval()
            with torch.no_grad():
                return self.net(state).argmax(-1).item()
        else:  # random action
            return torch.randint(self.n_actions, (1,)).item()

    def learn(self):
        self.net.train()
        self.target_net.eval()
        transitions = self.buffer.sample(self.batch_size)
        batch = Transition(*zip(*transitions))  # list of Transitions --> Transition of batches
        # Mask of transitions whose stored next_state is not None (i.e. non-terminal)
        non_final_mask = torch.tensor(tuple(map(lambda x: x is not None, batch.next_state)), device=self.device, dtype=torch.bool)
        non_final_next_states = torch.tensor([state for state in batch.next_state if state is not None], dtype=torch.float32, device=self.device)
        states = torch.stack(batch.state).to(self.device)                   # [bs, input_dim]
        rewards = torch.tensor(batch.reward).to(self.device)                # [bs, ]
        actions = torch.tensor(batch.action).unsqueeze(-1).to(self.device)  # [bs, ] --> [bs, 1]
        # Q(s, a) for the actions that were actually taken
        Q = self.net(states).gather(1, actions)  # [bs, n_actions] --> [bs, 1]
        # max_a' Q_target(s', a'), with 0 kept for terminal next states
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
        expected_state_action_values = (next_state_values * self.gamma) + rewards
        self.optim.zero_grad()
        loss = self.loss_fn(Q, expected_state_action_values.unsqueeze(1))
        loss.backward()
        self.optim.step()
        return loss.detach().item()

    def update_target_net(self):
        self.target_net.load_state_dict(self.net.state_dict())
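The agent above relies on a Transition namedtuple and a ReplayBuffer class that I have not pasted here; they follow the usual PyTorch DQN tutorial pattern. A minimal sketch of what they look like (my exact definitions may differ slightly):

import random
from collections import namedtuple, deque

# Minimal sketch of the replay memory the agent assumes (PyTorch DQN tutorial style).
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayBuffer():
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        # Store one transition
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        # Uniformly sample a batch of transitions
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)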
Here is my training loop code.
reward_history = []
for episode_i in tqdm(range(max_episodes)):
    s = env.reset()
    # Stop early once the average reward over the last 5 episodes is high enough
    if reward_history and np.mean(reward_history[-5:]) >= max_steps * 0.9:
        break
    ep_reward = 0
    while True:
        s = torch.tensor(s, dtype=torch.float32, device=device)
        a = agent.select_action(s)
        new_s, r, done, info = env.step(a)
        agent.store_transition(s, a, new_s, r)
        if done:
            reward_history.append(ep_reward)
            break
        s = new_s
        ep_reward += 1
        if len(agent.buffer) >= agent.batch_size:
            loss = agent.learn()
    print(f'Episode {episode_i} | Reward: {ep_reward}')
    if episode_i % target_update_intv == 0:
        agent.update_target_net()
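For completeness, the objects referenced above (device, env, net, agent, max_episodes, max_steps, target_update_intv) are created elsewhere in my script. A rough sketch of that wiring, where the CartPole environment, the network architecture, and the hyperparameter values are illustrative assumptions rather than my exact setup:

import gym
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

# Illustrative setup only; the environment and the values below are assumptions.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = gym.make('CartPole-v0')  # assumed environment (old Gym API: step() returns 4 values)

net = nn.Sequential(           # simple MLP Q-network
    nn.Linear(env.observation_space.shape[0], 128),
    nn.ReLU(),
    nn.Linear(128, env.action_space.n),
).to(device)

max_episodes = 1000            # illustrative values
max_steps = 200
target_update_intv = 10

agent = DQNAgent(net, capacity=10000, n_actions=env.action_space.n,
                 eps_start=0.9, eps_end=0.05, eps_decay=1000,
                 batch_size=128, gamma=0.99, lr=1e-3)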