Q-target fissi
Stai per addestrare il tuo Lunar Lander usando Q-target fissi. Come prerequisito, devi istanziare sia la rete online (che sceglie l’azione) sia la rete target (usata per il calcolo del TD-target).
Devi anche implementare una funzione update_target_network che potrai usare a ogni step di training. La rete target non viene aggiornata con la discesa del gradiente; invece, update_target_network spinge i suoi pesi verso quelli della Q-network di una piccola quantità, garantendo che rimanga abbastanza stabile nel tempo.
Nota che, solo per questo esercizio, usi una rete molto piccola così possiamo stampare e ispezionare facilmente il suo state dictionary. Ha un solo livello nascosto di dimensione due; anche lo spazio delle azioni e lo spazio degli stati hanno dimensione 2.
La funzione print_state_dict() è disponibile nel tuo ambiente per stampare lo state dict.
Questo esercizio fa parte del corso
Deep Reinforcement Learning in Python
Istruzioni dell'esercizio
- Ottieni lo
.state_dict()sia per la rete target sia per la rete online. - Aggiorna lo state dict della rete target facendo la media pesata tra i parametri della rete online e quelli della rete target, usando
taucome peso per la rete online. - Carica lo state dict aggiornato nella rete target.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
def update_target_network(target_network, online_network, tau):
# Obtain the state dicts for both networks
target_net_state_dict = ____
online_net_state_dict = ____
for key in online_net_state_dict:
# Calculate the updated state dict for the target network
target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
# Load the updated state dict into the target network
target_network.____
return None
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))