Cibles Q fixes
Vous vous apprêtez à entraîner votre Lunar Lander avec des cibles Q fixes. Au préalable, vous devez instancier à la fois le réseau en ligne (qui choisit l'action) et le réseau cible (utilisé pour le calcul de la cible TD).
Vous devez également implémenter une fonction d'update_target_network que vous pourrez utiliser à chaque étape de l'entraînement. Le réseau cible n'est pas mis à jour par descente de gradient ; à la place, l'update_target_network e ajuste légèrement ses poids vers le réseau Q, garantissant ainsi qu'il reste relativement stable au fil du temps.
Veuillez noter que, pour cet exercice uniquement, vous utilisez un réseau très petit afin que nous puissions facilement imprimer et inspecter son dictionnaire d'états. Il ne comporte qu'une seule couche cachée de taille deux ; son espace d'action et son espace d'état sont également de dimension 2.
La fonction print_state_dict() est disponible dans votre environnement pour imprimer le dictionnaire d'état.
Cet exercice fait partie du cours
Apprentissage par renforcement profond en Python
Instructions
- Veuillez obtenir l'adresse IP (
.state_dict()) des réseaux cible et en ligne. - Mettez à jour le dictionnaire d'états pour le réseau cible en calculant la moyenne pondérée entre les paramètres du réseau en ligne et ceux du réseau cible, en utilisant l'
taue comme poids pour le réseau en ligne. - Rechargez le dictionnaire d'état mis à jour sur le réseau cible.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
def update_target_network(target_network, online_network, tau):
# Obtain the state dicts for both networks
target_net_state_dict = ____
online_net_state_dict = ____
for key in online_net_state_dict:
# Calculate the updated state dict for the target network
target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
# Load the updated state dict into the target network
target_network.____
return None
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))