
DQN with prioritized experience replay

In this exercise, you will introduce Prioritized Experience Replay (PER) to improve the DQN algorithm. Rather than drawing transitions uniformly from the replay buffer, PER samples them in proportion to how surprising they were (their TD error), so each update is made on the most informative batch of transitions.
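
Concretely, a transition with priority p (typically its absolute TD error) is sampled with probability proportional to p ** alpha, and each sampled transition gets an importance-sampling weight (N * P(i)) ** (-beta) that corrects for the bias this introduces. The snippet below is a minimal illustration of those two formulas only; the priority values and the alpha/beta settings are made up for the example and are not part of the exercise code.

import numpy as np

# Illustrative priorities (absolute TD errors) for five stored transitions
priorities = np.array([2.0, 0.5, 1.0, 0.1, 3.0])
alpha, beta = 0.6, 0.4  # alpha: strength of prioritization, beta: strength of the bias correction

# Sampling probabilities: P(i) = p_i**alpha / sum_k p_k**alpha
probs = priorities ** alpha
probs /= probs.sum()

# Importance-sampling weights: w_i = (N * P(i))**(-beta), normalized by the maximum for stability
weights = (len(priorities) * probs) ** (-beta)
weights /= weights.max()

indices = np.random.choice(len(priorities), size=2, p=probs)
print(indices, weights[indices])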

For reference, the methods you have declared for PrioritizedReplayBuffer are listed here (a minimal sketch of a buffer with this interface follows the list):

  • push() (to push transitions to the buffer)
  • sample() (to sample a batch of transitions from the buffer)
  • increase_beta() (to increase the importance-sampling exponent beta over time)
  • update_priorities() (to update the priorities of the sampled transitions)
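
For intuition only, here is a minimal sketch of a proportional-priority buffer exposing that interface. It is an illustration under assumptions, not the course implementation: the constructor arguments, the beta schedule, and the fact that sample() here returns plain Python sequences rather than PyTorch tensors may all differ from the buffer provided in the exercise.

import numpy as np

class PrioritizedReplayBuffer:
    """Minimal proportional-priority buffer (an illustration, not the course implementation)."""

    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.buffer = []
        self.priorities = []

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        # New transitions get the current maximum priority so they are sampled at least once
        max_priority = max(self.priorities, default=1.0)
        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
            self.priorities.pop(0)
        self.buffer.append((state, action, reward, next_state, done))
        self.priorities.append(max_priority)

    def increase_beta(self):
        # Anneal beta toward 1 so the importance-sampling correction strengthens over time
        self.beta = min(1.0, self.beta + self.beta_increment)

    def sample(self, batch_size):
        # Sample indices in proportion to priority**alpha and compute importance-sampling weights
        probs = np.array(self.priorities) ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        states, actions, rewards, next_states, dones = zip(*(self.buffer[i] for i in indices))
        return states, actions, rewards, next_states, dones, indices, weights

    def update_priorities(self, indices, td_errors):
        # Priorities track the latest absolute TD error (plus a small constant to keep them non-zero)
        for i, err in zip(indices, td_errors):
            self.priorities[i] = abs(float(err)) + 1e-6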

The describe_episode() function is used again to summarize each episode.

This exercise is part of the course Deep Reinforcement Learning in Python.

Exercise instructions

  • Instantiate a Prioritized Experience Replay buffer with a capacity of 10000 transitions.
  • Increase the influence of importance sampling over time by updating the beta parameter.
  • Update the priority of the sampled experiences based on their latest TD error.

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False   
    step = 0
    episode_reward = 0    
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(batch_size)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)            
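            # TD errors measure how surprising each transition was; their magnitudes become the new priorities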
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
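            # Weight the squared TD errors by the importance-sampling weights to correct for the sampling bias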
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
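
If you get stuck, one possible completion of the three blanks is sketched below. It assumes the PrioritizedReplayBuffer constructor takes the capacity as its only argument, that increase_beta() takes no arguments, and that update_priorities() expects the sampled indices followed by their TD errors; depending on the implementation, the TD errors may first need to be detached, e.g. td_errors.detach().abs().

# One possible completion (method signatures assumed as described above)
replay_buffer = PrioritizedReplayBuffer(10000)

# Inside the episode loop:
replay_buffer.increase_beta()

# Inside the learning step, after computing td_errors:
replay_buffer.update_priorities(indices, td_errors)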