Barebone DQN función de pérdida

Con la función select_action() ya lista, sólo te falta un último paso para poder formar a tu agente: ahora pondrás en marcha calculate_loss().

La página calculate_loss() devuelve la pérdida de red para cualquier paso del episodio.

Como referencia, la pérdida viene dada por

En el ejercicio se han cargado los siguientes datos de ejemplo:

state = torch.rand(8)
next_state = torch.rand(8)
action = select_action(q_network, state)
reward = 1
gamma = .99
done = False

Este ejercicio forma parte del curso

Aprendizaje profundo por refuerzo en Python

Ver curso

Instrucciones de ejercicio

  • Obtener el valor Q del estado actual.
  • Obtén el valor Q del siguiente estado.
  • Calcula el valor Q objetivo, o TD-objetivo.
  • Calcula la función de pérdida, es decir, el Error de Bellman al cuadrado.

Ejercicio interactivo práctico

Pruebe este ejercicio completando este código de muestra.

def calculate_loss(q_network, state, action, next_state, reward, done):
    q_values = q_network(state)
    print(f'Q-values: {q_values}')
    # Obtain the current state Q-value
    current_state_q_value = q_values[____]
    print(f'Current state Q-value: {current_state_q_value:.2f}')
    # Obtain the next state Q-value
    next_state_q_value = q_network(next_state).____    
    print(f'Next state Q-value: {next_state_q_value:.2f}')
    # Calculate the target Q-value
    target_q_value = ____ + gamma * ____ * (1-done)
    print(f'Target Q-value: {target_q_value:.2f}')
    # Obtain the loss
    loss = nn.MSELoss()(____, ____)
    print(f'Loss: {loss:.2f}')
    return loss

calculate_loss(q_network, state, action, next_state, reward, done)