ComenzarEmpieza gratis

Objetivos Q fijos

Te estás preparando para entrenar tu Lunar Lander con objetivos Q fijos. Como requisito previo, necesitas instanciar tanto la red en línea (que elige la acción) como la red objetivo (utilizada para el cálculo TD-objetivo).

También tienes que implementar una función update_target_network que puedas utilizar en cada paso del entrenamiento. La red objetivo no se actualiza mediante el descenso de gradiente; en su lugar, update_target_network empuja sus pesos hacia la red Q en una pequeña cantidad, asegurándose de que permanece bastante estable a lo largo del tiempo.

Ten en cuenta que, sólo para este ejercicio, utilizas una red muy pequeña para que podamos imprimir e inspeccionar fácilmente su diccionario de estado. Sólo tiene una capa oculta de tamaño dos; su espacio de acción y su espacio de estado también son de dimensión 2.

La función print_state_dict() está disponible en tu entorno para imprimir el dictado de estado.

Este ejercicio forma parte del curso

Aprendizaje profundo por refuerzo en Python

Ver curso

Instrucciones de ejercicio

  • Obtén el .state_dict() tanto de la red de destino como de la red online.
  • Actualiza el dictado de estado para la red objetivo tomando la media ponderada entre los parámetros de la red en línea y de la red objetivo, utilizando tau como peso para la red en línea.
  • Vuelve a cargar el dictado de estado actualizado en la red de destino.

Ejercicio interactivo práctico

Pruebe este ejercicio completando este código de muestra.

def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = ____
    online_net_state_dict = ____
    for key in online_net_state_dict:
        # Calculate the updated state dict for the target network
        target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
        # Load the updated state dict into the target network
        target_network.____
    return None
  
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))
Editar y ejecutar código