MulaiMulai sekarang secara gratis

Buffer prioritized experience replay

Anda akan memperkenalkan kelas PrioritizedExperienceReplay, sebuah struktur data yang nantinya akan Anda gunakan untuk mengimplementasikan DQN dengan Prioritized Experience Replay.

PrioritizedExperienceReplay adalah penyempurnaan atas kelas ExperienceReplay yang sejauh ini Anda gunakan untuk melatih agen DQN. Prioritized experience replay buffer memastikan bahwa transisi yang diambil darinya lebih bernilai untuk dipelajari agen dibandingkan dengan pengambilan sampel seragam.

Untuk saat ini, implementasikan metode .__init__(), .push(), .update_priorities(), .increase_beta() dan .__len__(). Metode terakhir, .sample(), akan menjadi fokus pada latihan berikutnya.

Latihan ini adalah bagian dari kursus

Deep Reinforcement Learning dengan Python

Lihat Kursus

Petunjuk latihan

  • Dalam .push(), inisialisasi prioritas transisi ke prioritas maksimum di buffer (atau 1 jika buffer kosong).
  • Dalam .update_priorities(), atur prioritas ke nilai absolut dari TD error yang bersesuaian; tambahkan self.epsilon untuk menutup kasus tepi.
  • Dalam .increase_beta(), tingkatkan beta sebesar self.beta_increment; pastikan beta tidak pernah melebihi 1.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

class PrioritizedReplayBuffer:
    def __init__(
        self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
    ):
        self.memory = deque(maxlen=capacity)
        self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
        self.priorities = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Initialize the transition's priority
        max_priority = ____
        self.memory.append(experience_tuple)
        self.priorities.append(max_priority)
    
    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors.tolist()):
            # Update the transition's priority
            self.priorities[idx] = ____

    def increase_beta(self):
        # Increase beta if less than 1
        self.beta = ____

    def __len__(self):
        return len(self.memory)
      
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)
Edit dan Jalankan Kode