DQN con prioritized experience replay
In questo esercizio introdurrai il Prioritized Experience Replay (PER) per migliorare l'algoritmo DQN. L'obiettivo del PER è ottimizzare il batch di transizioni selezionato per aggiornare la rete a ogni passo.
Per riferimento, i nomi dei metodi che hai definito per PrioritizedReplayBuffer sono:
push()(per inserire transizioni nel buffer)sample()(per campionare un batch di transizioni dal buffer)increase_beta()(per aumentare l'importance sampling)update_priorities()(per aggiornare le priorità campionate)
La funzione describe_episode() viene usata di nuovo per descrivere ogni episodio.
Questo esercizio fa parte del corso
Deep Reinforcement Learning in Python
Istruzioni dell'esercizio
- Istanzia un buffer di Prioritized Experience Replay con una capacità di 10000 transizioni.
- Aumenta nel tempo l'influenza dell'importance sampling aggiornando il parametro
beta. - Aggiorna la priorità delle esperienze campionate in base al loro ultimo errore TD.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)
for episode in range(5):
state, info = env.reset()
done = False
step = 0
episode_reward = 0
# Increase the replay buffer's beta parameter
replay_buffer.____
while not done:
step += 1
total_steps += 1
q_values = online_network(state)
action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.push(state, action, reward, next_state, done)
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
next_q_values = target_network(next_states).amax(1)
target_q_values = rewards + gamma * next_q_values * (1-dones)
td_errors = target_q_values - q_values
# Update the replay buffer priorities for that batch
replay_buffer.____(____, ____)
loss = torch.sum(weights * (q_values - target_q_values) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_target_network(target_network, online_network, tau=.005)
state = next_state
episode_reward += reward
describe_episode(episode, reward, episode_reward, step)