Tampon de relecture d'expérience prioritaire
Vous allez présenter la classe PrioritizedExperienceReplay, une structure de données que vous utiliserez plus tard pour implémenter le DQN avec la méthode Prioritized Experience Replay.
PrioritizedExperienceReplay est une version améliorée de la classe ExperienceReplay que vous avez utilisée jusqu'à présent pour former vos agents DQN. Un tampon de relecture d'expériences hiérarchisé garantit que les transitions échantillonnées à partir de celui-ci sont plus utiles à l'agent pour apprendre qu'avec un échantillonnage uniforme.
Pour l'instant, veuillez implémenter les méthodes .__init__(), .push(), .update_priorities(), .increase_beta() et .__len__(). La dernière méthode, .sample(), sera abordée dans le prochain exercice.
Cet exercice fait partie du cours
Apprentissage par renforcement profond en Python
Instructions
- Dans
.push(), initialisez la priorité de la transition à la priorité maximale dans le tampon (ou 1 si le tampon est vide). - Dans
.update_priorities(), définissez la priorité sur la valeur absolue de l'erreur TD correspondante ; ajoutezself.epsilonpour couvrir les cas limites. - Dans
.increase_beta(), augmentez la valeur de beta deself.beta_increment; assurez-vous quebetane dépasse jamais 1.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
class PrioritizedReplayBuffer:
def __init__(
self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
):
self.memory = deque(maxlen=capacity)
self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
self.priorities = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Initialize the transition's priority
max_priority = ____
self.memory.append(experience_tuple)
self.priorities.append(max_priority)
def update_priorities(self, indices, td_errors):
for idx, td_error in zip(indices, td_errors.tolist()):
# Update the transition's priority
self.priorities[idx] = ____
def increase_beta(self):
# Increase beta if less than 1
self.beta = ____
def __len__(self):
return len(self.memory)
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)