
DQN with prioritized experience replay

In this exercise, you will introduce Prioritized Experience Replay (PER) to improve the DQN algorithm. Rather than sampling transitions uniformly at random, PER optimizes which transitions are selected for the batch used to update the network at each step.
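
In the usual proportional formulation of PER, each stored transition i carries a priority p_i = |δ_i| + ε based on the magnitude of its latest TD error δ_i, and is sampled with probability P(i) = p_i^α / Σ_k p_k^α, where α controls how strongly prioritization is applied (α = 0 recovers uniform sampling). Transitions whose value estimates were most wrong are therefore replayed most often.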

For reference, the method names you have declared for PrioritizedReplayBuffer are:

  • push() (to push transitions to the buffer)
  • sample() (to sample a batch of transitions from the buffer)
  • increase_beta() (to increase the strength of the importance-sampling correction)
  • update_priorities() (to update the priorities of the sampled transitions)
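
The exercise environment already provides PrioritizedReplayBuffer, so you only need to call it. For orientation, a minimal proportional-prioritization buffer exposing the same interface could look roughly like the sketch below; the alpha, beta, beta_increment, and eps defaults are illustrative assumptions, and the implementation used in the course may differ.

import numpy as np
import torch

class PrioritizedReplayBuffer:
    """Proportional PER: sampling probability is proportional to priority ** alpha."""

    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.01, eps=1e-5):
        self.capacity = capacity
        self.alpha = alpha                # 0 = uniform sampling, 1 = fully prioritized
        self.beta = beta                  # importance-sampling correction, annealed toward 1
        self.beta_increment = beta_increment
        self.eps = eps                    # keeps every priority strictly positive
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        # New transitions get the current maximum priority so they are sampled at least once
        max_priority = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.position] = (state, action, reward, next_state, done)
        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def increase_beta(self):
        # Anneal beta toward 1 so the bias correction strengthens over training
        self.beta = min(1.0, self.beta + self.beta_increment)

    def sample(self, batch_size):
        scaled = self.priorities[:len(self.buffer)] ** self.alpha
        probs = scaled / scaled.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        # Importance-sampling weights, normalized by the largest weight for stability
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights = torch.tensor(weights / weights.max(), dtype=torch.float32)
        states, actions, rewards, next_states, dones = zip(*(self.buffer[i] for i in indices))
        return (torch.tensor(np.array(states), dtype=torch.float32),
                torch.tensor(actions, dtype=torch.int64).unsqueeze(1),
                torch.tensor(rewards, dtype=torch.float32),
                torch.tensor(np.array(next_states), dtype=torch.float32),
                torch.tensor(dones, dtype=torch.float32),
                indices,
                weights)

    def update_priorities(self, indices, td_errors):
        # Priority = |TD error| + eps, so no transition's probability collapses to zero
        for i, td_error in zip(indices, td_errors):
            self.priorities[i] = abs(float(td_error)) + self.eps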

The describe_episode() function is used again to describe each episode.

This exercise is part of the course Deep Reinforcement Learning in Python.


Instructions

  • Instantiate a Prioritized Experience Replay buffer with a capacity of 10000 transitions.
  • Increase the influence of importance sampling over time by updating the beta parameter (see the note after these instructions).
  • Update the priority of the sampled experiences based on their latest TD error.
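
A note on the beta parameter: because prioritized sampling is no longer uniform, the gradient updates it produces are biased toward high-priority transitions. The weights returned by sample() correct for this through importance sampling, typically w_i = (N · P(i))^(−β) normalized by the largest weight in the batch, where N is the current buffer size. With β near 0 the correction is weak, while β = 1 removes the bias entirely, so calling increase_beta() once per episode gradually strengthens the correction as training progresses.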

Hands-on interactive exercise

Try this exercise by completing the sample code below.

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False   
    step = 0
    episode_reward = 0    
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(batch_size)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)            
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
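
To check your work: one plausible completion of the three blanks, assuming the constructor takes the capacity as its first argument and that update_priorities() accepts the sampled indices together with their TD errors (detaching the errors first, e.g. td_errors.detach(), keeps gradient history out of the buffer), is:

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = PrioritizedReplayBuffer(10000)

# Increase the replay buffer's beta parameter, once per episode
replay_buffer.increase_beta()

# Update the replay buffer priorities for that batch
replay_buffer.update_priorities(indices, td_errors)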