MulaiMulai sekarang secara gratis

DQN dengan prioritized experience replay

Dalam latihan ini, Anda akan menambahkan Prioritized Experience Replay (PER) untuk meningkatkan algoritma DQN. PER bertujuan mengoptimalkan himpunan transisi yang dipilih untuk memperbarui jaringan pada setiap langkah.

Sebagai acuan, nama metode yang telah Anda deklarasikan untuk PrioritizedReplayBuffer adalah:

  • push() (untuk mendorong transisi ke buffer)
  • sample() (untuk mengambil satu batch transisi dari buffer)
  • increase_beta() (untuk meningkatkan importance sampling)
  • update_priorities() (untuk memperbarui prioritas yang diambil sampelnya)

Fungsi describe_episode() kembali digunakan untuk mendeskripsikan setiap episode.

Latihan ini adalah bagian dari kursus

Deep Reinforcement Learning dengan Python

Lihat Kursus

Petunjuk latihan

  • Instansiasikan buffer Prioritized Experience Replay dengan kapasitas 10000 transisi.
  • Tingkatkan pengaruh importance sampling seiring waktu dengan memperbarui parameter beta.
  • Perbarui prioritas pengalaman yang diambil sampelnya berdasarkan kesalahan TD terbaru mereka.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False   
    step = 0
    episode_reward = 0    
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)            
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
Edit dan Jalankan Kode