Calcul des pertes pour les acteurs critiques
Avant de pouvoir entraîner votre agent avec A2C, veuillez écrire une fonction d'calculate_losses() qui renvoie les pertes pour les deux réseaux.
À titre de référence, voici les expressions correspondant respectivement aux fonctions de perte de l'acteur et du critique :
Cet exercice fait partie du cours
Apprentissage par renforcement profond en Python
Instructions
- Calculez la cible TD.
- Veuillez calculer la perte pour le réseau Actor.
- Veuillez calculer la perte pour le réseau Critic.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
def calculate_losses(critic_network, action_log_prob,
reward, state, next_state, done):
value = critic_network(state)
next_value = critic_network(next_state)
# Calculate the TD target
td_target = (____ + gamma * ____ * (1-done))
td_error = td_target - value
# Calculate the actor loss
actor_loss = -____ * ____.detach()
# Calculate the critic loss
critic_loss = ____
return actor_loss, critic_loss
actor_loss, critic_loss = calculate_losses(
critic_network, action_log_prob,
reward, state, next_state, done
)
print(round(actor_loss.item(), 2), round(critic_loss.item(), 2))