CommencerCommencer gratuitement

Fonction de perte DQN de Barebone

La fonction d'select_action() étant désormais prête, il ne vous reste plus qu'une étape pour pouvoir former votre agent : vous allez maintenant implémenter l'calculate_loss().

calculate_loss() La fonction « network_loss» (perte du réseau) renvoie la perte du réseau pour chaque étape de l'épisode.

À titre indicatif, la perte est calculée comme suit :

Les données suivantes ont été chargées dans l'exercice :

state = torch.rand(8)
next_state = torch.rand(8)
action = select_action(q_network, state)
reward = 1
gamma = .99
done = False

Cet exercice fait partie du cours

Apprentissage par renforcement profond en Python

Afficher le cours

Instructions

  • Obtenir la valeur Q de l'état actuel.
  • Obtenir la valeur Q de l'état suivant.
  • Calculez la valeur Q cible, ou cible TD.
  • Calculez la fonction de perte, c'est-à-dire l'erreur de Bellman au carré.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

def calculate_loss(q_network, state, action, next_state, reward, done):
    q_values = q_network(state)
    print(f'Q-values: {q_values}')
    # Obtain the current state Q-value
    current_state_q_value = q_values[____]
    print(f'Current state Q-value: {current_state_q_value:.2f}')
    # Obtain the next state Q-value
    next_state_q_value = q_network(next_state).____    
    print(f'Next state Q-value: {next_state_q_value:.2f}')
    # Calculate the target Q-value
    target_q_value = ____ + gamma * ____ * (1-done)
    print(f'Target Q-value: {target_q_value:.2f}')
    # Obtain the loss
    loss = nn.MSELoss()(____, ____)
    print(f'Loss: {loss:.2f}')
    return loss

calculate_loss(q_network, state, action, next_state, reward, done)
Modifier et exécuter le code