DQN met prioritaire experience replay
In deze oefening voeg je Prioritized Experience Replay (PER) toe om het DQN-algoritme te verbeteren. PER probeert de batch transities die bij elke stap wordt geselecteerd om het netwerk te updaten, te optimaliseren.
Ter referentie, de methoden die je hebt gedefinieerd voor PrioritizedReplayBuffer zijn:
push()(om transities in de buffer te plaatsen)sample()(om een batch transities uit de buffer te halen)increase_beta()(om importance sampling te vergroten)update_priorities()(om de gesamplede prioriteiten bij te werken)
De functie describe_episode() wordt opnieuw gebruikt om elke episode te beschrijven.
Deze oefening maakt deel uit van de cursus
Deep Reinforcement Learning in Python
Oefeninstructies
- Maak een Prioritized Experience Replay-buffer met een capaciteit van 10000 transities.
- Vergroot de invloed van importance sampling in de tijd door de parameter
betate updaten. - Werk de prioriteit van de gesamplede ervaringen bij op basis van hun nieuwste TD-fout.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)
for episode in range(5):
state, info = env.reset()
done = False
step = 0
episode_reward = 0
# Increase the replay buffer's beta parameter
replay_buffer.____
while not done:
step += 1
total_steps += 1
q_values = online_network(state)
action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.push(state, action, reward, next_state, done)
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
next_q_values = target_network(next_states).amax(1)
target_q_values = rewards + gamma * next_q_values * (1-dones)
td_errors = target_q_values - q_values
# Update the replay buffer priorities for that batch
replay_buffer.____(____, ____)
loss = torch.sum(weights * (q_values - target_q_values) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_target_network(target_network, online_network, tau=.005)
state = next_state
episode_reward += reward
describe_episode(episode, reward, episode_reward, step)