DQN with prioritized experience replay
In this exercise, you will introduce Prioritized Experience Replay (PER) to improve the DQN algorithm. Instead of sampling uniformly from the replay buffer, PER samples transitions in proportion to their TD error, so each training batch concentrates on the transitions the network currently predicts worst.
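As background, in the standard PER formulation (Schaul et al., 2015), a transition with priority p_i is sampled with probability

    P(i) = p_i^alpha / sum_k p_k^alpha

and its contribution to the loss is scaled by the importance-sampling weight

    w_i = (1 / (N * P(i)))^beta

where N is the number of stored transitions. Your PrioritizedReplayBuffer may differ in the details, but this is the role of the two exponents: annealing beta toward 1 over training is what increase_beta() is for, and update_priorities() refreshes p_i, typically with the latest absolute TD error plus a small constant.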
For reference, the method names you have declared for PrioritizedReplayBuffer are:
- push() (to push transitions to the buffer)
- sample() (to sample a batch of transitions from the buffer)
- increase_beta() (to increase the influence of importance sampling)
- update_priorities() (to update the sampled priorities)
The describe_episode() function is used again to describe each episode.
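The course environment provides this class for you. Purely as a reference for what such an interface could look like, here is a minimal sketch of a proportional-priority buffer with the same method names. Everything beyond those names is an assumption: the constructor arguments (capacity, alpha, beta, beta_increment), the annealing schedule in increase_beta(), and the return types of sample() (chosen to match how the exercise code unpacks states, actions, rewards, next_states, dones, indices, weights).

import numpy as np
import torch

class PrioritizedReplayBuffer:
    """Illustrative proportional-priority buffer; not the course's implementation."""

    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=1e-3):
        self.capacity = capacity
        self.alpha = alpha            # how strongly priorities shape sampling
        self.beta = beta              # strength of the importance-sampling correction
        self.beta_increment = beta_increment
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        # New transitions get the current maximum priority so they are
        # sampled at least once before their TD error is known.
        max_priority = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.position] = (state, action, reward, next_state, done)
        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def increase_beta(self):
        # Anneal beta toward 1 so the importance-sampling correction
        # becomes stronger as training progresses.
        self.beta = min(1.0, self.beta + self.beta_increment)

    def sample(self, batch_size):
        priorities = self.priorities[: len(self.buffer)] ** self.alpha
        probs = priorities / priorities.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)

        # Importance-sampling weights, normalized by their maximum.
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights = weights / weights.max()

        states, actions, rewards, next_states, dones = zip(
            *(self.buffer[i] for i in indices)
        )
        return (
            torch.as_tensor(np.array(states), dtype=torch.float32),
            torch.as_tensor(actions, dtype=torch.int64).unsqueeze(1),
            torch.as_tensor(rewards, dtype=torch.float32),
            torch.as_tensor(np.array(next_states), dtype=torch.float32),
            torch.as_tensor(dones, dtype=torch.float32),
            indices,
            torch.as_tensor(weights, dtype=torch.float32),
        )

    def update_priorities(self, indices, td_errors):
        # Larger absolute TD error means a higher priority on the next sample.
        for i, td_error in zip(indices, td_errors):
            self.priorities[i] = abs(float(td_error)) + 1e-5

The key design choice is that new transitions enter with the current maximum priority, so every stored transition is sampled at least once before its priority is replaced by its measured TD error.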
Exercise instructions
- Instantiate a Prioritized Experience Replay buffer with a capacity of 10000 transitions.
- Increase the influence of importance sampling over time by updating the beta parameter.
- Update the priority of the sampled experiences based on their latest TD error.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False
    step = 0
    episode_reward = 0
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward
    describe_episode(episode, reward, episode_reward, step)
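If you get stuck, here is one way the blanks could be filled in. It assumes the capacity is the constructor's first argument and that update_priorities() takes the sampled indices followed by their TD errors; both match the instructions above, but check the signatures of your own class.

replay_buffer = PrioritizedReplayBuffer(10000)        # capacity of 10000 transitions
replay_buffer.increase_beta()                         # once per episode, strengthen the IS correction
replay_buffer.update_priorities(indices, td_errors)   # refresh priorities with the latest TD errors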