Mulai sekarangMulai gratis

Buffer prioritized experience replay

Anda akan memperkenalkan kelas PrioritizedExperienceReplay, sebuah struktur data yang nantinya akan Anda gunakan untuk mengimplementasikan DQN dengan Prioritized Experience Replay.

PrioritizedExperienceReplay adalah penyempurnaan atas kelas ExperienceReplay yang sejauh ini Anda gunakan untuk melatih agen DQN. Prioritized experience replay buffer memastikan bahwa transisi yang diambil darinya lebih bernilai untuk dipelajari agen dibandingkan dengan pengambilan sampel seragam.

Untuk saat ini, implementasikan metode .__init__(), .push(), .update_priorities(), .increase_beta() dan .__len__(). Metode terakhir, .sample(), akan menjadi fokus pada latihan berikutnya.

Latihan ini merupakan bagian dari kursus

Deep Reinforcement Learning dengan Python

Lihat Kursus

Instruksi latihan

  • Dalam .push(), inisialisasi prioritas transisi ke prioritas maksimum di buffer (atau 1 jika buffer kosong).
  • Dalam .update_priorities(), atur prioritas ke nilai absolut dari TD error yang bersesuaian; tambahkan self.epsilon untuk menutup kasus tepi.
  • Dalam .increase_beta(), tingkatkan beta sebesar self.beta_increment; pastikan beta tidak pernah melebihi 1.

Latihan interaktif langsung praktik

Cobalah latihan ini dengan melengkapi kode contoh ini.

class PrioritizedReplayBuffer:
    def __init__(
        self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
    ):
        self.memory = deque(maxlen=capacity)
        self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
        self.priorities = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Initialize the transition's priority
        max_priority = ____
        self.memory.append(experience_tuple)
        self.priorities.append(max_priority)
    
    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors.tolist()):
            # Update the transition's priority
            self.priorities[idx] = ____

    def increase_beta(self):
        # Increase beta if less than 1
        self.beta = ____

    def __len__(self):
        return len(self.memory)
      
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)
Edit dan Jalankan Kode