DQN avec relecture d'expérience priorisée
Dans cet exercice, vous allez introduire la technique de rejouer les expériences prioritaires (Prioritized Experience Replay, PER) afin d'améliorer l'algorithme DQN. Le PER vise à optimiser le lot de transitions sélectionnées pour mettre à jour le réseau à chaque étape.
À titre de référence, les noms des méthodes que vous avez déclarées pour l'PrioritizedReplayBuffer sont les suivants :
push()(pour transférer les transitions vers la mémoire tampon)sample()(pour échantillonner un lot de transitions à partir de la mémoire tampon)increase_beta()(pour augmenter l'échantillonnage par importance)update_priorities()(pour mettre à jour les priorités échantillonnées)
La fonction « describe_episode() » est à nouveau utilisée pour décrire chaque épisode.
Cet exercice fait partie du cours
Apprentissage par renforcement profond en Python
Instructions
- Instancier un tampon de relecture d'expérience priorisée avec une capacité de 10 000 transitions.
- Augmentez l'influence de l'échantillonnage pondéré au fil du temps en mettant à jour le paramètre «
beta». - Mettre à jour la priorité des expériences échantillonnées en fonction de leur dernière erreur TD.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)
for episode in range(5):
state, info = env.reset()
done = False
step = 0
episode_reward = 0
# Increase the replay buffer's beta parameter
replay_buffer.____
while not done:
step += 1
total_steps += 1
q_values = online_network(state)
action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.push(state, action, reward, next_state, done)
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
next_q_values = target_network(next_states).amax(1)
target_q_values = rewards + gamma * next_q_values * (1-dones)
td_errors = target_q_values - q_values
# Update the replay buffer priorities for that batch
replay_buffer.____(____, ____)
loss = torch.sum(weights * (q_values - target_q_values) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_target_network(target_network, online_network, tau=.005)
state = next_state
episode_reward += reward
describe_episode(episode, reward, episode_reward, step)