Objetivos Q fijos
Te estás preparando para entrenar tu Lunar Lander con objetivos Q fijos. Como requisito previo, necesitas instanciar tanto la red en línea (que elige la acción) como la red objetivo (utilizada para el cálculo TD-objetivo).
También tienes que implementar una función update_target_network
que puedas utilizar en cada paso del entrenamiento. La red objetivo no se actualiza mediante el descenso de gradiente; en su lugar, update_target_network
empuja sus pesos hacia la red Q en una pequeña cantidad, asegurándose de que permanece bastante estable a lo largo del tiempo.
Ten en cuenta que, sólo para este ejercicio, utilizas una red muy pequeña para que podamos imprimir e inspeccionar fácilmente su diccionario de estado. Sólo tiene una capa oculta de tamaño dos; su espacio de acción y su espacio de estado también son de dimensión 2.
La función print_state_dict()
está disponible en tu entorno para imprimir el dictado de estado.
Este ejercicio forma parte del curso
Aprendizaje profundo por refuerzo en Python
Instrucciones de ejercicio
- Obtén el
.state_dict()
tanto de la red de destino como de la red online. - Actualiza el dictado de estado para la red objetivo tomando la media ponderada entre los parámetros de la red en línea y de la red objetivo, utilizando
tau
como peso para la red en línea. - Vuelve a cargar el dictado de estado actualizado en la red de destino.
Ejercicio interactivo práctico
Pruebe este ejercicio completando este código de muestra.
def update_target_network(target_network, online_network, tau):
# Obtain the state dicts for both networks
target_net_state_dict = ____
online_net_state_dict = ____
for key in online_net_state_dict:
# Calculate the updated state dict for the target network
target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
# Load the updated state dict into the target network
target_network.____
return None
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))