Buffer prioritized experience replay
Anda akan memperkenalkan kelas PrioritizedExperienceReplay, sebuah struktur data yang nantinya akan Anda gunakan untuk mengimplementasikan DQN dengan Prioritized Experience Replay.
PrioritizedExperienceReplay adalah penyempurnaan atas kelas ExperienceReplay yang sejauh ini Anda gunakan untuk melatih agen DQN. Prioritized experience replay buffer memastikan bahwa transisi yang diambil darinya lebih bernilai untuk dipelajari agen dibandingkan dengan pengambilan sampel seragam.
Untuk saat ini, implementasikan metode .__init__(), .push(), .update_priorities(), .increase_beta() dan .__len__(). Metode terakhir, .sample(), akan menjadi fokus pada latihan berikutnya.
Latihan ini adalah bagian dari kursus
Deep Reinforcement Learning dengan Python
Petunjuk latihan
- Dalam
.push(), inisialisasi prioritas transisi ke prioritas maksimum di buffer (atau 1 jika buffer kosong). - Dalam
.update_priorities(), atur prioritas ke nilai absolut dari TD error yang bersesuaian; tambahkanself.epsilonuntuk menutup kasus tepi. - Dalam
.increase_beta(), tingkatkan beta sebesarself.beta_increment; pastikanbetatidak pernah melebihi 1.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
class PrioritizedReplayBuffer:
def __init__(
self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
):
self.memory = deque(maxlen=capacity)
self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
self.priorities = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Initialize the transition's priority
max_priority = ____
self.memory.append(experience_tuple)
self.priorities.append(max_priority)
def update_priorities(self, indices, td_errors):
for idx, td_error in zip(indices, td_errors.tolist()):
# Update the transition's priority
self.priorities[idx] = ____
def increase_beta(self):
# Increase beta if less than 1
self.beta = ____
def __len__(self):
return len(self.memory)
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)