Cálculo de las pérdidas del Actor Crítico
Como paso final antes de que puedas entrenar a tu agente con A2C, escribe una función calculate_losses()
que devuelva las pérdidas de ambas redes.
Como referencia, éstas son las expresiones de las funciones de pérdida del actor y del crítico, respectivamente:
Este ejercicio forma parte del curso
Aprendizaje profundo por refuerzo en Python
Instrucciones de ejercicio
- Calcula el objetivo TD.
- Calcula la pérdida de la red Actor.
- Calcula la pérdida de la red Critic.
Ejercicio interactivo práctico
Pruebe este ejercicio completando este código de muestra.
def calculate_losses(critic_network, action_log_prob,
reward, state, next_state, done):
value = critic_network(state)
next_value = critic_network(next_state)
# Calculate the TD target
td_target = (____ + gamma * ____ * (1-done))
td_error = td_target - value
# Calculate the actor loss
actor_loss = -____ * ____.detach()
# Calculate the critic loss
critic_loss = ____
return actor_loss, critic_loss
actor_loss, critic_loss = calculate_losses(
critic_network, action_log_prob,
reward, state, next_state, done
)
print(round(actor_loss.item(), 2), round(critic_loss.item(), 2))