Aan de slagGa gratis aan de slag

Gefixeerde Q-targets

Je gaat je Lunar Lander trainen met gefixeerde Q-targets. Als voorbereiding moet je zowel het online netwerk (dat de actie kiest) als het target-netwerk (gebruikt voor de TD-targetberekening) instantieren.

Je moet ook een functie update_target_network implementeren die je bij elke trainingsstap kunt gebruiken. Het target-netwerk wordt niet met gradient descent geüpdatet; in plaats daarvan duwt update_target_network de gewichten een klein stukje richting het Q-netwerk, zodat het over tijd stabiel blijft.

Let op: alleen voor deze oefening gebruik je een heel klein netwerk zodat we de state dictionary eenvoudig kunnen afdrukken en bekijken. Het heeft slechts één verborgen laag van grootte twee; de actieruimte en toestandsruimte hebben ook dimensie 2.

De functie print_state_dict() is beschikbaar in je omgeving om de state dict af te drukken.

Deze oefening maakt deel uit van de cursus

Deep Reinforcement Learning in Python

Cursus bekijken

Oefeninstructies

  • Haal de .state_dict() op voor zowel het target- als het online netwerk.
  • Werk de state dict voor het target-netwerk bij door het gewogen gemiddelde te nemen van de parameters van het online netwerk en het target-netwerk, waarbij je tau gebruikt als gewicht voor het online netwerk.
  • Laad de bijgewerkte state dict terug in het target-netwerk.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = ____
    online_net_state_dict = ____
    for key in online_net_state_dict:
        # Calculate the updated state dict for the target network
        target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
        # Load the updated state dict into the target network
        target_network.____
    return None
  
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))
Code bewerken en uitvoeren