Fixed Q-targets
You are preparing to train your Lunar Lander with fixed Q-targets. As a prerequisite, you need to instantiate both the online network (which chooses the action) and the target network (used for TD-target calculation).
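As an illustration of this setup, here is a minimal sketch. It assumes a hypothetical `QNetwork` class (the course's own class is not shown here) and uses Lunar Lander's 8-dimensional state and 4 discrete actions. The target network starts as an exact copy of the online network, and only the online network's parameters go to the optimizer.

```python
import copy

import torch
from torch import nn


class QNetwork(nn.Module):
    """Hypothetical Q-network: maps a state to one Q-value per action."""

    def __init__(self, state_size: int, action_size: int, hidden_size: int = 64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.layers(state)


# Online network: chooses actions and is trained by gradient descent
online_network = QNetwork(state_size=8, action_size=4)
# Target network: starts as an independent, identical copy of the online network
target_network = copy.deepcopy(online_network)
# Only the online network's parameters are given to the optimizer
optimizer = torch.optim.Adam(online_network.parameters(), lr=1e-4)
```

An equivalent alternative to `copy.deepcopy` is to build a second `QNetwork` and call `target_network.load_state_dict(online_network.state_dict())`.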
You also need to implement an update_target_network function, which you can call at each training step. The target network is not updated by gradient descent; instead, update_target_network nudges its weights a small amount towards those of the online network, so the target network changes only slowly and stays stable over time.
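The "small nudge" is a soft update: for every parameter, target ← tau × online + (1 − tau) × target. The plain-Python sketch below applies it to a single scalar weight, with the online value held fixed at 1.0 purely for illustration, to show how slowly the target moves when tau = 0.001.

```python
def soft_update(target_value: float, online_value: float, tau: float) -> float:
    """One soft update: move the target a fraction tau of the way toward the online value."""
    return tau * online_value + (1 - tau) * target_value


tau = 0.001
target, online = 0.0, 1.0  # online weight held fixed at 1.0 for illustration only
for step in range(1000):
    target = soft_update(target, online, tau)

# After n steps the remaining gap is (1 - tau) ** n of the original gap
print(round(target, 4))  # 0.6323
```

Even after 1,000 updates the target has closed only about 63% of the gap, which is what keeps the TD targets stable.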
Note that, for this exercise only, you use a very small network so that its state dict can easily be printed and inspected. It has a single hidden layer of size 2, and its state and action spaces are also 2-dimensional.
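A network of that size might look like the sketch below. The exact architecture used in the course is not shown, so the `nn.Sequential` layout with a ReLU activation is an assumption. It has only four parameter tensors (12 numbers in total), which is why its state dict is easy to read.

```python
import torch
from torch import nn

# State dimension 2 -> hidden layer of size 2 -> 2 actions (one Q-value per action)
tiny_network = nn.Sequential(
    nn.Linear(2, 2),
    nn.ReLU(),  # no parameters, so it contributes no state-dict entries
    nn.Linear(2, 2),
)

for name, tensor in tiny_network.state_dict().items():
    print(name, tuple(tensor.shape))
```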
The function print_state_dict() is available in your environment to print the state dict.
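The course's print_state_dict() implementation is not shown. The sketch below is a hypothetical stand-in for running the code outside the course environment. It returns a formatted string rather than printing directly, an assumption made so that calls of the form `print("label:", print_state_dict(net))` display the values.

```python
from torch import nn


def print_state_dict(network: nn.Module) -> str:
    """Hypothetical stand-in: format each state-dict entry as 'name: values'."""
    lines = [f"{name}: {tensor.tolist()}" for name, tensor in network.state_dict().items()]
    return "\n" + "\n".join(lines)


print("weights:", print_state_dict(nn.Linear(2, 2)))
```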
This exercise is part of the course Deep Reinforcement Learning in Python.
Exercise instructions
- Obtain the .state_dict() for both the target and online networks.
- Update the state dict for the target network by taking the weighted average of the parameters of the online network and of the target network, using tau as the weight for the online network (and 1 - tau for the target network).
- Load the updated state dict back into the target network.
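The last step is needed because of how state dicts behave. The tensors returned by `.state_dict()` share storage with the network's parameters, but the weighted average builds new tensors, and assigning a new tensor to a key only changes the dictionary. The sketch below, using a throwaway `nn.Linear` layer, shows that the network is unchanged until `load_state_dict()` copies the values in.

```python
import torch
from torch import nn

net = nn.Linear(2, 2)
state = net.state_dict()

# Rebinding a key to a NEW tensor changes only the dictionary, not the network
state["weight"] = torch.zeros(2, 2)
print(torch.equal(net.weight, torch.zeros(2, 2)))  # False

# load_state_dict copies the dictionary's tensors into the network's parameters
net.load_state_dict(state)
print(torch.equal(net.weight, torch.zeros(2, 2)))  # True
```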
Have a go at this exercise by completing this sample code.
```python
def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = ____
    online_net_state_dict = ____
    for key in online_net_state_dict:
        # Calculate the updated state dict for the target network
        target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
    # Load the updated state dict into the target network
    target_network.____
    return None

print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))
```
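One way to complete the scaffold, following the three instructions, is sketched below. It is followed by a check on two freshly initialized tiny networks; their `nn.Sequential` architecture is an assumption standing in for the course's networks. The check confirms that every target parameter ends up as 0.001 × online + 0.999 × previous target.

```python
import torch
from torch import nn


def update_target_network(target_network: nn.Module, online_network: nn.Module, tau: float) -> None:
    # Obtain the state dicts for both networks
    target_net_state_dict = target_network.state_dict()
    online_net_state_dict = online_network.state_dict()
    for key in online_net_state_dict:
        # Weighted average: tau for the online network, (1 - tau) for the target network
        target_net_state_dict[key] = (
            online_net_state_dict[key] * tau + target_net_state_dict[key] * (1 - tau)
        )
    # Load the updated state dict into the target network
    target_network.load_state_dict(target_net_state_dict)


online_network = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
target_network = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
# Snapshot the target weights before the update (clone, since state-dict tensors share storage)
before = {k: v.clone() for k, v in target_network.state_dict().items()}

update_target_network(target_network, online_network, 0.001)

for key, new_value in target_network.state_dict().items():
    expected = 0.001 * online_network.state_dict()[key] + 0.999 * before[key]
    print(key, torch.allclose(new_value, expected))  # True for every key
```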