CommencerCommencer gratuitement

Sélection d'actions DQN de base

La fonction « select_action() » permet à l'agent de sélectionner l'action ayant la valeur Q la plus élevée à chaque étape.

La fonction prend comme arguments le réseau Q et l'état actuel, et renvoie l'index de l'action ayant la valeur Q la plus élevée.

Le réseau Q est instancié sous le nom d'q_network, et un état aléatoire a été chargé dans votre environnement avec state = torch.rand(8) afin de vous fournir des exemples de données avec lesquels travailler.

Cet exercice fait partie du cours

Apprentissage par renforcement profond en Python

Afficher le cours

Instructions

  • Calculez les valeurs Q correspondant à chaque action dans l'état fourni en argument.
  • Obtenez l'index correspondant à l'action ayant la valeur Q la plus élevée.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

def select_action(q_network, state):
    # Calculate the Q-values
    q_values = ____
    print("Q-values:", [round(x, 2) for x in q_values.tolist()])
    # Obtain the action index with highest Q-value
    action = torch.____.item()
    print(f"Action selected: {action}, with q-value {q_values[action]:.2f}")
    return action

select_action(q_network, state)
Modifier et exécuter le code