
Fixed Q-targets

You are preparing to train your Lunar Lander agent with fixed Q-targets. As a prerequisite, you need to instantiate two networks: the online network, which chooses the actions, and the target network, which is used to compute the TD targets.

You also need to implement an update_target_network function to call at each training step. The target network is not updated by gradient descent. Instead, update_target_network nudges its weights a small amount towards those of the online network, so the target network changes only slowly and the TD targets stay relatively stable over time.
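Written as a formula, the update performed by update_target_network is what is often called a soft (or Polyak) update. Here θ denotes the online network's parameters and θ⁻ the target network's parameters. These symbols are notation introduced for this sketch; they are not defined in the exercise.

```latex
\theta^{-} \leftarrow \tau\,\theta + (1 - \tau)\,\theta^{-}, \qquad 0 < \tau \ll 1
```

The update is applied entry by entry to every tensor in the state dict. A small τ keeps each new target close to its previous value.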

Note that, for this exercise only, you use a very small network so that its state dictionary is easy to print and inspect. It has a single hidden layer of size 2, and both its state space and its action space have dimension 2.
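The PyTorch sketch below shows one way to build such a tiny network and to instantiate both the online and the target network. The class name TinyQNetwork and the ReLU activation are assumptions for illustration, because the exercise environment provides its own networks. Initializing the target network as a deep copy of the online network is one common choice, not necessarily the one the course uses.

```python
import copy

import torch
import torch.nn as nn


class TinyQNetwork(nn.Module):
    """Hypothetical stand-in for the exercise's network: 2 -> 2 -> 2."""

    def __init__(self, state_size=2, hidden_size=2, action_size=2):
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)   # single hidden layer of size 2
        self.fc2 = nn.Linear(hidden_size, action_size)  # one Q-value per action

    def forward(self, state):
        return self.fc2(torch.relu(self.fc1(state)))


# Online network: chooses actions and is trained by gradient descent
online_network = TinyQNetwork()
# Target network: starts as an independent, exact copy of the online network
target_network = copy.deepcopy(online_network)

# List the state dict entries that a soft update would iterate over
for key, tensor in online_network.state_dict().items():
    print(key, tuple(tensor.shape))
```

The printed keys (two weight matrices and two bias vectors) are the entries that update_target_network loops over.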

The function print_state_dict() is available in your environment to print a network's state dict.

This exercise is part of the course Deep Reinforcement Learning in Python.

Exercise instructions

  • Obtain the .state_dict() for both the target and online networks.
  • Update the state dict for the target network by taking the weighted average of the parameters of the online network and those of the target network, using tau as the weight for the online network and 1 - tau as the weight for the target network.
  • Load the updated state dict back onto the target network.
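The plain-Python sketch below shows what the weighted average does to a single scalar parameter, using the tau = 0.001 value from the sample code. The helper soft_update_scalar is hypothetical and not part of the exercise.

```python
def soft_update_scalar(target, online, tau):
    """Weighted average for one parameter: weight tau on the online value."""
    return online * tau + target * (1 - tau)


online_value = 1.0
target_value = 0.0
tau = 0.001

# One update moves the target 0.1% of the way toward the online value
after_one = soft_update_scalar(target_value, online_value, tau)
print(after_one)  # 0.001

# Repeated updates with the online value held fixed close the gap geometrically:
# remaining gap after n steps = (1 - tau) ** n * initial gap
value = target_value
for _ in range(1000):
    value = soft_update_scalar(value, online_value, tau)
print(value)
```

After 1,000 updates, the target has covered only about 63% of the gap, because the remaining gap is (1 - tau)^1000 ≈ 0.37 of the original. During real training the online weights keep changing, so by design the target lags behind them.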

Have a go at this exercise by completing this sample code.

def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = ____
    online_net_state_dict = ____
    for key in online_net_state_dict:
        # Calculate the updated state dict for the target network
        target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
    # Load the updated state dict into the target network (once, after the loop)
    target_network.____

# print_state_dict() prints directly, so call it on its own line
print("online network weights:")
print_state_dict(online_network)
print("target network weights (pre-update):")
print_state_dict(target_network)
update_target_network(target_network, online_network, 0.001)
print("target network weights (post-update):")
print_state_dict(target_network)
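The self-contained PyTorch sketch below shows one way to complete update_target_network. The two nn.Sequential networks, the random seed, and the before/after comparison are assumptions that stand in for the exercise environment, which provides online_network, target_network, and print_state_dict itself. They are not the course's own setup.

```python
import copy

import torch
import torch.nn as nn


def update_target_network(target_network, online_network, tau):
    """Soft update: move each target parameter a fraction tau toward the online one."""
    target_net_state_dict = target_network.state_dict()
    online_net_state_dict = online_network.state_dict()
    for key in online_net_state_dict:
        # Weighted average: tau for the online network, (1 - tau) for the target
        target_net_state_dict[key] = (
            online_net_state_dict[key] * tau
            + target_net_state_dict[key] * (1 - tau)
        )
    # Load once, after every entry has been averaged
    target_network.load_state_dict(target_net_state_dict)


# Hypothetical 2 -> 2 -> 2 networks standing in for the exercise environment
torch.manual_seed(0)
online_network = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
target_network = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))

# Deep copy: state_dict() tensors share storage with the live parameters
before = copy.deepcopy(target_network.state_dict())

update_target_network(target_network, online_network, 0.001)

after = target_network.state_dict()
for key in after:
    # Largest change per entry: 0.001 times the online/target difference
    print(key, (after[key] - before[key]).abs().max().item())
```

Loading the averaged dict once after the loop gives the same result as loading it on every iteration, with less work. The deep copy of the pre-update state dict is needed because load_state_dict() writes into the parameters' storage in place. Without the copy, `before` would silently show the post-update values.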