Feste Q-Ziele
Du bereitest dich darauf vor, deinen Lunar Lander mit festen Q-Zielen zu trainieren. Als Voraussetzung musst du sowohl das Online-Netz (das die Aktion auswählt) als auch das Zielnetz (für die TD-Zielberechnung) instanziieren.
Du musst auch eine update_target_network
Funktion implementieren, die du bei jedem Trainingsschritt verwenden kannst. Das Zielnetz wird nicht durch Gradientenabstieg aktualisiert. Stattdessen schiebt update_target_network
seine Gewichte um einen kleinen Betrag in Richtung des Q-Netzes und stellt so sicher, dass es im Laufe der Zeit recht stabil bleibt.
Beachte, dass du für diese Übung nur ein sehr kleines Netzwerk verwendest, damit wir sein Zustandswörterbuch leicht ausdrucken und untersuchen können. Es hat nur eine verborgene Schicht der Größe zwei; sein Aktionsraum und sein Zustandsraum sind ebenfalls von der Dimension 2.
Die Funktion print_state_dict()
ist in deiner Umgebung verfügbar, um den Status dict zu drucken.
Diese Übung ist Teil des Kurses
Deep Reinforcement Learning in Python
Anleitung zur Übung
- Besorge dir die
.state_dict()
sowohl für das Ziel- als auch für das Online-Netzwerk. - Aktualisiere das Zustandsdiktat für das Zielnetz, indem du den gewichteten Durchschnitt zwischen den Parametern des Online-Netzes und des Zielnetzes ermittelst und
tau
als Gewicht für das Online-Netz verwendest. - Lade den aktualisierten Status dict zurück in das Zielnetzwerk.
Interaktive Übung zum Anfassen
Probieren Sie diese Übung aus, indem Sie diesen Beispielcode ausführen.
def update_target_network(target_network, online_network, tau):
# Obtain the state dicts for both networks
target_net_state_dict = ____
online_net_state_dict = ____
for key in online_net_state_dict:
# Calculate the updated state dict for the target network
target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
# Load the updated state dict into the target network
target_network.____
return None
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))