DQN com reprodução de experiência priorizada
Neste exercício, você introduzirá o Prioritized Experience Replay (PER) para aprimorar o algoritmo DQN. O objetivo do PER é otimizar o lote de transições selecionadas para atualizar a rede em cada etapa.
Para referência, os nomes dos métodos que você declarou para PrioritizedReplayBuffer
são:
push()
(para enviar transições para o buffer)sample()
(para obter uma amostra de um lote de transições do buffer)increase_beta()
(para aumentar a importância da amostragem)update_priorities()
(para atualizar as prioridades da amostra)
A função describe_episode()
é usada novamente para descrever cada episódio.
Este exercício faz parte do curso
Aprendizado por reforço profundo em Python
Instruções do exercício
- Instanciar um buffer de repetição de experiência priorizada com capacidade de 10.000 transições.
- Aumente a influência da amostragem de importância ao longo do tempo atualizando o parâmetro
beta
. - Atualize a prioridade das experiências amostradas com base em seu último erro TD.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)
for episode in range(5):
state, info = env.reset()
done = False
step = 0
episode_reward = 0
# Increase the replay buffer's beta parameter
replay_buffer.____
while not done:
step += 1
total_steps += 1
q_values = online_network(state)
action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.push(state, action, reward, next_state, done)
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
next_q_values = target_network(next_states).amax(1)
target_q_values = rewards + gamma * next_q_values * (1-dones)
td_errors = target_q_values - q_values
# Update the replay buffer priorities for that batch
replay_buffer.____(____, ____)
loss = torch.sum(weights * (q_values - target_q_values) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_target_network(target_network, online_network, tau=.005)
state = next_state
episode_reward += reward
describe_episode(episode, reward, episode_reward, step)