Erste SchritteKostenlos loslegen

Feste Q-Ziele

Du bereitest dich darauf vor, deinen Lunar Lander mit festen Q-Zielen zu trainieren. Als Voraussetzung musst du sowohl das Online-Netz (das die Aktion auswählt) als auch das Zielnetz (für die TD-Zielberechnung) instanziieren.

Du musst auch eine update_target_network Funktion implementieren, die du bei jedem Trainingsschritt verwenden kannst. Das Zielnetz wird nicht durch Gradientenabstieg aktualisiert. Stattdessen schiebt update_target_network seine Gewichte um einen kleinen Betrag in Richtung des Q-Netzes und stellt so sicher, dass es im Laufe der Zeit recht stabil bleibt.

Beachte, dass du für diese Übung nur ein sehr kleines Netzwerk verwendest, damit wir sein Zustandswörterbuch leicht ausdrucken und untersuchen können. Es hat nur eine verborgene Schicht der Größe zwei; sein Aktionsraum und sein Zustandsraum sind ebenfalls von der Dimension 2.

Die Funktion print_state_dict() ist in deiner Umgebung verfügbar, um den Status dict zu drucken.

Diese Übung ist Teil des Kurses

Deep Reinforcement Learning in Python

Kurs anzeigen

Anleitung zur Übung

  • Besorge dir die .state_dict() sowohl für das Ziel- als auch für das Online-Netzwerk.
  • Aktualisiere das Zustandsdiktat für das Zielnetz, indem du den gewichteten Durchschnitt zwischen den Parametern des Online-Netzes und des Zielnetzes ermittelst und tau als Gewicht für das Online-Netz verwendest.
  • Lade den aktualisierten Status dict zurück in das Zielnetzwerk.

Interaktive Übung zum Anfassen

Probieren Sie diese Übung aus, indem Sie diesen Beispielcode ausführen.

def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = ____
    online_net_state_dict = ____
    for key in online_net_state_dict:
        # Calculate the updated state dict for the target network
        target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
        # Load the updated state dict into the target network
        target_network.____
    return None
  
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))
Bearbeiten und Ausführen von Code