Buffer prioritized experience replay
Anda akan memperkenalkan kelas PrioritizedExperienceReplay, sebuah struktur data yang nantinya akan Anda gunakan untuk mengimplementasikan DQN dengan Prioritized Experience Replay.
PrioritizedExperienceReplay adalah penyempurnaan atas kelas ExperienceReplay yang sejauh ini Anda gunakan untuk melatih agen DQN. Prioritized experience replay buffer memastikan bahwa transisi yang diambil darinya lebih bernilai untuk dipelajari agen dibandingkan dengan pengambilan sampel seragam.
Untuk saat ini, implementasikan metode .__init__(), .push(), .update_priorities(), .increase_beta() dan .__len__(). Metode terakhir, .sample(), akan menjadi fokus pada latihan berikutnya.
Latihan ini merupakan bagian dari kursus
Deep Reinforcement Learning dengan Python
Instruksi latihan
- Dalam
.push(), inisialisasi prioritas transisi ke prioritas maksimum di buffer (atau 1 jika buffer kosong). - Dalam
.update_priorities(), atur prioritas ke nilai absolut dari TD error yang bersesuaian; tambahkanself.epsilonuntuk menutup kasus tepi. - Dalam
.increase_beta(), tingkatkan beta sebesarself.beta_increment; pastikanbetatidak pernah melebihi 1.
Latihan interaktif langsung praktik
Cobalah latihan ini dengan melengkapi kode contoh ini.
class PrioritizedReplayBuffer:
def __init__(
self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
):
self.memory = deque(maxlen=capacity)
self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
self.priorities = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Initialize the transition's priority
max_priority = ____
self.memory.append(experience_tuple)
self.priorities.append(max_priority)
def update_priorities(self, indices, td_errors):
for idx, td_error in zip(indices, td_errors.tolist()):
# Update the transition's priority
self.priorities[idx] = ____
def increase_beta(self):
# Increase beta if less than 1
self.beta = ____
def __len__(self):
return len(self.memory)
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)