DQN mit priorisierter Erfahrungswiedergabe
In dieser Übung führst du Prioritized Experience Replay (PER) ein, um den DQN Algorithmus zu verbessern. PER zielt darauf ab, die Menge der Übergänge zu optimieren, die bei jedem Schritt zur Aktualisierung des Netzes ausgewählt werden.
Die Methodennamen, die du für PrioritizedReplayBuffer
angegeben hast, lauten wie folgt:
push()
(um Übergänge in den Puffer zu schieben)sample()
(um einen Stapel von Übergängen aus dem Puffer zu entnehmen)increase_beta()
(um die Bedeutung der Stichproben zu erhöhen)update_priorities()
(um die gesampelten Prioritäten zu aktualisieren)
Die Funktion describe_episode()
wird wieder verwendet, um jede Episode zu beschreiben.
Diese Übung ist Teil des Kurses
Deep Reinforcement Learning in Python
Anleitung zur Übung
- Richte einen Puffer für priorisierte Erfahrungswiedergabe mit einer Kapazität von 10000 Übergängen ein.
- Erhöhe den Einfluss des Wichtigkeitssamplings im Laufe der Zeit, indem du den Parameter
beta
aktualisierst. - Aktualisiere die Priorität der gesampelten Erfahrungen auf der Grundlage ihres letzten TD Fehlers.
Interaktive Übung
Versuche dich an dieser Übung, indem du diesen Beispielcode vervollständigst.
# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)
for episode in range(5):
state, info = env.reset()
done = False
step = 0
episode_reward = 0
# Increase the replay buffer's beta parameter
replay_buffer.____
while not done:
step += 1
total_steps += 1
q_values = online_network(state)
action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.push(state, action, reward, next_state, done)
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
next_q_values = target_network(next_states).amax(1)
target_q_values = rewards + gamma * next_q_values * (1-dones)
td_errors = target_q_values - q_values
# Update the replay buffer priorities for that batch
replay_buffer.____(____, ____)
loss = torch.sum(weights * (q_values - target_q_values) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_target_network(target_network, online_network, tau=.005)
state = next_state
episode_reward += reward
describe_episode(episode, reward, episode_reward, step)