CommencerCommencer gratuitement

DQN avec relecture d'expérience priorisée

Dans cet exercice, vous allez introduire la technique de rejouer les expériences prioritaires (Prioritized Experience Replay, PER) afin d'améliorer l'algorithme DQN. Le PER vise à optimiser le lot de transitions sélectionnées pour mettre à jour le réseau à chaque étape.

À titre de référence, les noms des méthodes que vous avez déclarées pour l'PrioritizedReplayBuffer sont les suivants :

  • push() (pour transférer les transitions vers la mémoire tampon)
  • sample() (pour échantillonner un lot de transitions à partir de la mémoire tampon)
  • increase_beta() (pour augmenter l'échantillonnage par importance)
  • update_priorities() (pour mettre à jour les priorités échantillonnées)

La fonction « describe_episode() » est à nouveau utilisée pour décrire chaque épisode.

Cet exercice fait partie du cours

Apprentissage par renforcement profond en Python

Afficher le cours

Instructions

  • Instancier un tampon de relecture d'expérience priorisée avec une capacité de 10 000 transitions.
  • Augmentez l'influence de l'échantillonnage pondéré au fil du temps en mettant à jour le paramètre « beta ».
  • Mettre à jour la priorité des expériences échantillonnées en fonction de leur dernière erreur TD.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False   
    step = 0
    episode_reward = 0    
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)            
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
Modifier et exécuter le code