
Fixed Q-targets

You are preparing to train your Lunar Lander with fixed Q-targets. As a prerequisite, you need to instantiate both the online network (which chooses the action) and the target network (used for TD-target calculation).

You also need to implement an update_target_network function which you can call at each training step. The target network is not updated by gradient descent; instead, update_target_network nudges its weights towards the online network's weights by a small amount, ensuring that the target network remains stable over time.
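
Concretely, for each parameter tensor θ, this soft update computes

    θ_target ← tau * θ_online + (1 - tau) * θ_target

so with a small tau, the target network's weights move only a tiny step towards the online network's weights at each call.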

Note that, for this exercise only, you use a very small network so that its state dictionary can easily be printed and inspected. It has only one hidden layer of size two; its state and action spaces are also of dimension 2.
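
For reference, a minimal sketch of how such a pair of networks could be set up in PyTorch. The exercise environment instantiates these networks for you; the make_net helper and the plain nn.Sequential architecture below are assumptions matching the description above, not the environment's actual code.

import torch.nn as nn

def make_net():
    # Tiny network (assumed layout): 2-dimensional state -> hidden layer of size 2 -> 2 action values
    return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))

online_network = make_net()
target_network = make_net()
# Start the target network as an exact copy of the online network
target_network.load_state_dict(online_network.state_dict())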

The function print_state_dict() is available in your environment to print the state dict.
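
The helper's exact implementation is not shown; a plausible minimal version (an assumption, not the environment's actual code) might look like:

def print_state_dict(network):
    # Print each named parameter tensor in the network's state dict
    for name, tensor in network.state_dict().items():
        print(name, tensor)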

This exercise is part of the course Deep Reinforcement Learning in Python.

Instructions

  • Obtain the .state_dict() for both the target and online networks.
  • Update the state dict for the target network by taking the weighted average between the parameters of the online network and of the target network, using tau as weight for the online network.
  • Load the updated state dict back onto the target network.

Hands-on interactive exercise

Try this exercise by completing this sample code.

def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = target_network.state_dict()
    online_net_state_dict = online_network.state_dict()
    for key in online_net_state_dict:
        # Weighted average: tau weights the online network, (1 - tau) the target network
        target_net_state_dict[key] = (online_net_state_dict[key] * tau
                                      + target_net_state_dict[key] * (1 - tau))
    # Load the updated state dict back onto the target network (once, after the loop)
    target_network.load_state_dict(target_net_state_dict)
    return None
  
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))