CommencerCommencer gratuitement

Tampon de relecture d'expérience prioritaire

Vous allez présenter la classe PrioritizedExperienceReplay, une structure de données que vous utiliserez plus tard pour implémenter le DQN avec la méthode Prioritized Experience Replay.

PrioritizedExperienceReplay est une version améliorée de la classe ExperienceReplay que vous avez utilisée jusqu'à présent pour former vos agents DQN. Un tampon de relecture d'expériences hiérarchisé garantit que les transitions échantillonnées à partir de celui-ci sont plus utiles à l'agent pour apprendre qu'avec un échantillonnage uniforme.

Pour l'instant, veuillez implémenter les méthodes .__init__(), .push(), .update_priorities(), .increase_beta() et .__len__(). La dernière méthode, .sample(), sera abordée dans le prochain exercice.

Cet exercice fait partie du cours

Apprentissage par renforcement profond en Python

Afficher le cours

Instructions

  • Dans .push(), initialisez la priorité de la transition à la priorité maximale dans le tampon (ou 1 si le tampon est vide).
  • Dans .update_priorities(), définissez la priorité sur la valeur absolue de l'erreur TD correspondante ; ajoutez self.epsilon pour couvrir les cas limites.
  • Dans .increase_beta(), augmentez la valeur de beta de self.beta_increment; assurez-vous que beta ne dépasse jamais 1.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

class PrioritizedReplayBuffer:
    def __init__(
        self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
    ):
        self.memory = deque(maxlen=capacity)
        self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
        self.priorities = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Initialize the transition's priority
        max_priority = ____
        self.memory.append(experience_tuple)
        self.priorities.append(max_priority)
    
    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors.tolist()):
            # Update the transition's priority
            self.priorities[idx] = ____

    def increase_beta(self):
        # Increase beta if less than 1
        self.beta = ____

    def __len__(self):
        return len(self.memory)
      
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)
Modifier et exécuter le code