CommencerCommencer gratuitement

Cibles Q fixes

Vous vous apprêtez à entraîner votre Lunar Lander avec des cibles Q fixes. Au préalable, vous devez instancier à la fois le réseau en ligne (qui choisit l'action) et le réseau cible (utilisé pour le calcul de la cible TD).

Vous devez également implémenter une fonction d'update_target_network que vous pourrez utiliser à chaque étape de l'entraînement. Le réseau cible n'est pas mis à jour par descente de gradient ; à la place, l'update_target_network e ajuste légèrement ses poids vers le réseau Q, garantissant ainsi qu'il reste relativement stable au fil du temps.

Veuillez noter que, pour cet exercice uniquement, vous utilisez un réseau très petit afin que nous puissions facilement imprimer et inspecter son dictionnaire d'états. Il ne comporte qu'une seule couche cachée de taille deux ; son espace d'action et son espace d'état sont également de dimension 2.

La fonction print_state_dict() est disponible dans votre environnement pour imprimer le dictionnaire d'état.

Cet exercice fait partie du cours

Apprentissage par renforcement profond en Python

Afficher le cours

Instructions

  • Veuillez obtenir l'adresse IP ( .state_dict() ) des réseaux cible et en ligne.
  • Mettez à jour le dictionnaire d'états pour le réseau cible en calculant la moyenne pondérée entre les paramètres du réseau en ligne et ceux du réseau cible, en utilisant l'tau e comme poids pour le réseau en ligne.
  • Rechargez le dictionnaire d'état mis à jour sur le réseau cible.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = ____
    online_net_state_dict = ____
    for key in online_net_state_dict:
        # Calculate the updated state dict for the target network
        target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
        # Load the updated state dict into the target network
        target_network.____
    return None
  
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))
Modifier et exécuter le code