Gefixeerde Q-targets
Je gaat je Lunar Lander trainen met gefixeerde Q-targets. Als voorbereiding moet je zowel het online netwerk (dat de actie kiest) als het target-netwerk (gebruikt voor de TD-targetberekening) instantieren.
Je moet ook een functie update_target_network implementeren die je bij elke trainingsstap kunt gebruiken. Het target-netwerk wordt niet met gradient descent geüpdatet; in plaats daarvan duwt update_target_network de gewichten een klein stukje richting het Q-netwerk, zodat het over tijd stabiel blijft.
Let op: alleen voor deze oefening gebruik je een heel klein netwerk zodat we de state dictionary eenvoudig kunnen afdrukken en bekijken. Het heeft slechts één verborgen laag van grootte twee; de actieruimte en toestandsruimte hebben ook dimensie 2.
De functie print_state_dict() is beschikbaar in je omgeving om de state dict af te drukken.
Deze oefening maakt deel uit van de cursus
Deep Reinforcement Learning in Python
Oefeninstructies
- Haal de
.state_dict()op voor zowel het target- als het online netwerk. - Werk de state dict voor het target-netwerk bij door het gewogen gemiddelde te nemen van de parameters van het online netwerk en het target-netwerk, waarbij je
taugebruikt als gewicht voor het online netwerk. - Laad de bijgewerkte state dict terug in het target-netwerk.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
def update_target_network(target_network, online_network, tau):
# Obtain the state dicts for both networks
target_net_state_dict = ____
online_net_state_dict = ____
for key in online_net_state_dict:
# Calculate the updated state dict for the target network
target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
# Load the updated state dict into the target network
target_network.____
return None
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))