Prioritized experience replay buffer

You will implement the PrioritizedReplayBuffer class, a data structure that you will later use to implement DQN with Prioritized Experience Replay.

PrioritizedReplayBuffer is a refinement of the ExperienceReplay class that you have been using so far to train your DQN agents. A prioritized experience replay buffer samples transitions with large TD errors more often, so the batches drawn from it are, on average, more informative for the agent to learn from than batches drawn with uniform sampling.
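
To make this concrete, here is a minimal sketch (separate from the exercise scaffold) of how priorities translate into sampling probabilities under standard proportional prioritization: each transition is drawn with probability proportional to its priority raised to the power alpha, so alpha=0 recovers uniform sampling. The priority values below are made up for illustration.

import numpy as np

# Hypothetical priorities for three stored transitions
priorities = np.array([0.1, 0.5, 2.0])
alpha = 0.6  # plays the same role as the alpha argument stored in __init__ below
scaled = priorities ** alpha
probs = scaled / scaled.sum()
print(probs.round(3))  # roughly [0.104 0.272 0.625]: the largest priority dominates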

For now, implement the methods .__init__(), .push(), .update_priorities(), .increase_beta() and .__len__(). The final method, .sample(), will be the focus of the next exercise.

This exercise is part of the course Deep Reinforcement Learning in Python.

Instructions

  • In .push(), initialize the transition's priority to the maximum priority in the buffer (or 1 if the buffer is empty).
  • In .update_priorities(), set the priority to the absolute value of the corresponding TD error; add self.epsilon so that transitions whose TD error is zero keep a small, nonzero priority and can still be sampled.
  • In .increase_beta(), increment beta by self.beta_increment, making sure beta never exceeds 1. (One possible completion of all three methods is sketched after the scaffold below.)
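
Why anneal beta toward 1? Prioritized sampling biases gradient updates, and the standard correction multiplies each sampled transition's loss by an importance-sampling weight w = (N * P(i)) ** (-beta), normalized by the largest weight; the bias is only fully corrected when beta reaches 1. A minimal sketch of this standard PER correction, reusing the hypothetical probabilities from the earlier sketch:

import numpy as np

# Hypothetical sampling probabilities for a buffer of N = 3 transitions
probs = np.array([0.104, 0.272, 0.625])
N = len(probs)
for beta in (0.4, 1.0):
    weights = (N * probs) ** (-beta)
    weights /= weights.max()  # normalize by the max weight, as in standard PER
    print(beta, weights.round(3))
# beta = 1 fully compensates for the non-uniform sampling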

Interactive practice exercise

Try this exercise by completing this sample code.

from collections import deque

class PrioritizedReplayBuffer:
    def __init__(
        self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
    ):
        self.memory = deque(maxlen=capacity)
        self.priorities = deque(maxlen=capacity)
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.epsilon = epsilon

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Initialize the transition's priority
        max_priority = ____
        self.memory.append(experience_tuple)
        self.priorities.append(max_priority)
    
    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors.tolist()):
            # Update the transition's priority
            self.priorities[idx] = ____

    def increase_beta(self):
        # Increase beta if less than 1
        self.beta = ____

    def __len__(self):
        return len(self.memory)
      
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1, 3], action=2, reward=1, next_state=[2, 4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)