DQN dengan prioritized experience replay
Dalam latihan ini, Anda akan menambahkan Prioritized Experience Replay (PER) untuk meningkatkan algoritma DQN. PER bertujuan mengoptimalkan himpunan transisi yang dipilih untuk memperbarui jaringan pada setiap langkah.
Sebagai acuan, nama metode yang telah Anda deklarasikan untuk PrioritizedReplayBuffer adalah:
push()(untuk mendorong transisi ke buffer)sample()(untuk mengambil satu batch transisi dari buffer)increase_beta()(untuk meningkatkan importance sampling)update_priorities()(untuk memperbarui prioritas yang diambil sampelnya)
Fungsi describe_episode() kembali digunakan untuk mendeskripsikan setiap episode.
Latihan ini adalah bagian dari kursus
Deep Reinforcement Learning dengan Python
Petunjuk latihan
- Instansiasikan buffer Prioritized Experience Replay dengan kapasitas 10000 transisi.
- Tingkatkan pengaruh importance sampling seiring waktu dengan memperbarui parameter
beta. - Perbarui prioritas pengalaman yang diambil sampelnya berdasarkan kesalahan TD terbaru mereka.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)
for episode in range(5):
state, info = env.reset()
done = False
step = 0
episode_reward = 0
# Increase the replay buffer's beta parameter
replay_buffer.____
while not done:
step += 1
total_steps += 1
q_values = online_network(state)
action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.push(state, action, reward, next_state, done)
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
next_q_values = target_network(next_states).amax(1)
target_q_values = rewards + gamma * next_q_values * (1-dones)
td_errors = target_q_values - q_values
# Update the replay buffer priorities for that batch
replay_buffer.____(____, ____)
loss = torch.sum(weights * (q_values - target_q_values) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_target_network(target_network, online_network, tau=.005)
state = next_state
episode_reward += reward
describe_episode(episode, reward, episode_reward, step)