CommencerCommencer gratuitement

Formation du double DQN

Vous allez maintenant modifier votre code DQN afin d'implémenter le double DQN.

Le Double DQN ne nécessite qu'un ajustement minimal de l'algorithme DQN, mais contribue grandement à résoudre le problème de surestimation de la valeur Q et offre souvent de meilleures performances que le DQN.

Cet exercice fait partie du cours

Apprentissage par renforcement profond en Python

Afficher le cours

Instructions

  • Calculez les actions suivantes pour le calcul de la cible Q à l'aide de l'online_network(), en veillant à obtenir l'action et la forme appropriées.
  • Estimez les valeurs Q de ces actions à l'aide de l'target_network(), en veillant à nouveau à obtenir les valeurs et la forme correctes.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

for episode in range(10):
    state, info = env.reset()
    done = False
    step = 0
    episode_reward = 0
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones = replay_buffer.sample(64)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                # Obtain next actions for Q-target calculation
                next_actions = ____.____.____
                # Estimate next Q-values from these actions
                next_q_values = ____.____.____
                target_q_values = rewards + gamma * next_q_values * (1-dones)
            loss = nn.MSELoss()(q_values, target_q_values)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
Modifier et exécuter le code