Aan de slagGa gratis aan de slag

Prioritized experience replay-buffer

Je gaat de klasse PrioritizedExperienceReplay introduceren, een datastructuur die je later gebruikt om DQN met Prioritized Experience Replay te implementeren.

PrioritizedExperienceReplay is een verfijning van de klasse ExperienceReplay die je tot nu toe hebt gebruikt om je DQN-agents te trainen. Een prioritized experience replay-buffer zorgt ervoor dat de getrokken transities waardevoller zijn voor het leerproces van de agent dan bij uniforme sampling.

Implementeer nu de methoden .__init__(), .push(), .update_priorities(), .increase_beta() en .__len__(). De laatste methode, .sample(), staat centraal in de volgende oefening.

Deze oefening maakt deel uit van de cursus

Deep Reinforcement Learning in Python

Cursus bekijken

Oefeninstructies

  • Stel in .push() de prioriteit van de transitie in op de maximale prioriteit in de buffer (of 1 als de buffer leeg is).
  • Stel in .update_priorities() de prioriteit gelijk aan de absolute waarde van de bijbehorende TD-fout; voeg self.epsilon toe om randgevallen af te dekken.
  • Verhoog in .increase_beta() beta met self.beta_increment; zorg ervoor dat beta nooit boven 1 uitkomt.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

class PrioritizedReplayBuffer:
    def __init__(
        self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
    ):
        self.memory = deque(maxlen=capacity)
        self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
        self.priorities = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Initialize the transition's priority
        max_priority = ____
        self.memory.append(experience_tuple)
        self.priorities.append(max_priority)
    
    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors.tolist()):
            # Update the transition's priority
            self.priorities[idx] = ____

    def increase_beta(self):
        # Increase beta if less than 1
        self.beta = ____

    def __len__(self):
        return len(self.memory)
      
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)
Code bewerken en uitvoeren