ComeçarComece de graça

DQN com reprodução de experiência priorizada

Neste exercício, você introduzirá o Prioritized Experience Replay (PER) para aprimorar o algoritmo DQN. O objetivo do PER é otimizar o lote de transições selecionadas para atualizar a rede em cada etapa.

Para referência, os nomes dos métodos que você declarou para PrioritizedReplayBuffer são:

  • push() (para enviar transições para o buffer)
  • sample() (para obter uma amostra de um lote de transições do buffer)
  • increase_beta() (para aumentar a importância da amostragem)
  • update_priorities() (para atualizar as prioridades da amostra)

A função describe_episode() é usada novamente para descrever cada episódio.

Este exercício faz parte do curso

Aprendizado por reforço profundo em Python

Ver curso

Instruções do exercício

  • Instanciar um buffer de repetição de experiência priorizada com capacidade de 10.000 transições.
  • Aumente a influência da amostragem de importância ao longo do tempo atualizando o parâmetro beta.
  • Atualize a prioridade das experiências amostradas com base em seu último erro TD.

Exercício interativo prático

Experimente este exercício completando este código de exemplo.

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False   
    step = 0
    episode_reward = 0    
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)            
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
Editar e executar o código