IniziaInizia gratis

Q-target fissi

Stai per addestrare il tuo Lunar Lander usando Q-target fissi. Come prerequisito, devi istanziare sia la rete online (che sceglie l’azione) sia la rete target (usata per il calcolo del TD-target).

Devi anche implementare una funzione update_target_network che potrai usare a ogni step di training. La rete target non viene aggiornata con la discesa del gradiente; invece, update_target_network spinge i suoi pesi verso quelli della Q-network di una piccola quantità, garantendo che rimanga abbastanza stabile nel tempo.

Nota che, solo per questo esercizio, usi una rete molto piccola così possiamo stampare e ispezionare facilmente il suo state dictionary. Ha un solo livello nascosto di dimensione due; anche lo spazio delle azioni e lo spazio degli stati hanno dimensione 2.

La funzione print_state_dict() è disponibile nel tuo ambiente per stampare lo state dict.

Questo esercizio fa parte del corso

Deep Reinforcement Learning in Python

Visualizza il corso

Istruzioni dell'esercizio

  • Ottieni lo .state_dict() sia per la rete target sia per la rete online.
  • Aggiorna lo state dict della rete target facendo la media pesata tra i parametri della rete online e quelli della rete target, usando tau come peso per la rete online.
  • Carica lo state dict aggiornato nella rete target.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = ____
    online_net_state_dict = ____
    for key in online_net_state_dict:
        # Calculate the updated state dict for the target network
        target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
        # Load the updated state dict into the target network
        target_network.____
    return None
  
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))
Modifica ed esegui il codice