Mise en œuvre de l'algorithme DQN complet
Le moment est enfin arrivé ! Toutes les conditions préalables sont remplies ; vous allez maintenant mettre en œuvre l'algorithme DQN complet et l'utiliser pour former un agent Lunar Lander. Cela signifie que votre algorithme utilisera non seulement la relecture de l'expérience, mais également la méthode epsilon-avide décroissante et les cibles Q fixes.
select_action() La fonction « decayedepsilongreedy» qui implémente l'algorithme Decayed Epsilon Greediness est à votre disposition, tout comme la fonction « update_target_network() » de l'exercice précédent. Il ne reste plus qu'à intégrer ces fonctions dans la boucle d'entraînement DQN et à s'assurer que vous utilisez correctement le réseau cible dans les calculs de perte.
Vous devez conserver un nouveau compteur de pas, total_steps, afin de réduire progressivement la valeur de \(\varepsilon\) au fil du temps. Cette variable est initialisée pour vous avec la valeur 0.
Cet exercice fait partie du cours
Apprentissage par renforcement profond en Python
Instructions
- Utilisez
select_action()pour implémenter l'algorithme Decayed Epsilon Greediness et sélectionner l'action de l'agent ; vous devrez utilisertotal_steps, le total cumulé sur l'ensemble des épisodes. - Avant de calculer l'objectif TD, désactivez le suivi de gradient.
- Après avoir obtenu l'état suivant, obtenez les valeurs Q de l'état suivant.
- Veuillez mettre à jour le réseau cible à la fin de chaque étape.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
for episode in range(10):
state, info = env.reset()
done = False
step = 0
episode_reward = 0
while not done:
step += 1
total_steps += 1
q_values = online_network(state)
# Select the action with epsilon greediness
action = ____(____, ____, start=.9, end=.05, decay=1000)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.push(state, action, reward, next_state, done)
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones = replay_buffer.sample(64)
q_values = online_network(states).gather(1, actions).squeeze(1)
# Ensure gradients are not tracked
with ____:
# Obtain the next state Q-values
next_q_values = ____(next_states).amax(1)
target_q_values = rewards + gamma * next_q_values * (1-dones)
loss = nn.MSELoss()(q_values, target_q_values)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Update the target network weights
____(____, ____, tau=.005)
state = next_state
episode_reward += reward
describe_episode(episode, reward, episode_reward, step)