IniziaInizia gratis

DQN con prioritized experience replay

In questo esercizio introdurrai il Prioritized Experience Replay (PER) per migliorare l'algoritmo DQN. L'obiettivo del PER è ottimizzare il batch di transizioni selezionato per aggiornare la rete a ogni passo.

Per riferimento, i nomi dei metodi che hai definito per PrioritizedReplayBuffer sono:

  • push() (per inserire transizioni nel buffer)
  • sample() (per campionare un batch di transizioni dal buffer)
  • increase_beta() (per aumentare l'importance sampling)
  • update_priorities() (per aggiornare le priorità campionate)

La funzione describe_episode() viene usata di nuovo per descrivere ogni episodio.

Questo esercizio fa parte del corso

Deep Reinforcement Learning in Python

Visualizza il corso

Istruzioni dell'esercizio

  • Istanzia un buffer di Prioritized Experience Replay con una capacità di 10000 transizioni.
  • Aumenta nel tempo l'influenza dell'importance sampling aggiornando il parametro beta.
  • Aggiorna la priorità delle esperienze campionate in base al loro ultimo errore TD.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False   
    step = 0
    episode_reward = 0    
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)            
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
Modifica ed esegui il codice