Prioritized experience replay-buffer
Je gaat de klasse PrioritizedExperienceReplay introduceren, een datastructuur die je later gebruikt om DQN met Prioritized Experience Replay te implementeren.
PrioritizedExperienceReplay is een verfijning van de klasse ExperienceReplay die je tot nu toe hebt gebruikt om je DQN-agents te trainen. Een prioritized experience replay-buffer zorgt ervoor dat de getrokken transities waardevoller zijn voor het leerproces van de agent dan bij uniforme sampling.
Implementeer nu de methoden .__init__(), .push(), .update_priorities(), .increase_beta() en .__len__(). De laatste methode, .sample(), staat centraal in de volgende oefening.
Deze oefening maakt deel uit van de cursus
Deep Reinforcement Learning in Python
Oefeninstructies
- Stel in
.push()de prioriteit van de transitie in op de maximale prioriteit in de buffer (of 1 als de buffer leeg is). - Stel in
.update_priorities()de prioriteit gelijk aan de absolute waarde van de bijbehorende TD-fout; voegself.epsilontoe om randgevallen af te dekken. - Verhoog in
.increase_beta()beta metself.beta_increment; zorg ervoor datbetanooit boven 1 uitkomt.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
class PrioritizedReplayBuffer:
def __init__(
self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
):
self.memory = deque(maxlen=capacity)
self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
self.priorities = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Initialize the transition's priority
max_priority = ____
self.memory.append(experience_tuple)
self.priorities.append(max_priority)
def update_priorities(self, indices, td_errors):
for idx, td_error in zip(indices, td_errors.tolist()):
# Update the transition's priority
self.priorities[idx] = ____
def increase_beta(self):
# Increase beta if less than 1
self.beta = ____
def __len__(self):
return len(self.memory)
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)