Prioritized experience replay buffer
Introdurrai la classe PrioritizedExperienceReplay, una struttura dati che userai più avanti per implementare DQN con Prioritized Experience Replay.
PrioritizedExperienceReplay è un perfezionamento della classe ExperienceReplay che hai usato finora per addestrare i tuoi agenti DQN. Un prioritized experience replay buffer fa sì che le transizioni campionate siano, in media, più utili per l'apprendimento dell'agente rispetto a un campionamento uniforme.
Per ora, implementa i metodi .__init__(), .push(), .update_priorities(), .increase_beta() e .__len__(). Il metodo finale, .sample(), sarà l'obiettivo del prossimo esercizio.
Questo esercizio fa parte del corso
Deep Reinforcement Learning in Python
Istruzioni dell'esercizio
- In
.push(), inizializza la priorità della transizione alla priorità massima nel buffer (oppure a 1 se il buffer è vuoto). - In
.update_priorities(), imposta la priorità al valore assoluto del corrispondente TD error; aggiungiself.epsilonper coprire i casi limite. - In
.increase_beta(), incrementa beta diself.beta_increment; assicurati chebetanon superi mai 1.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
class PrioritizedReplayBuffer:
def __init__(
self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
):
self.memory = deque(maxlen=capacity)
self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
self.priorities = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Initialize the transition's priority
max_priority = ____
self.memory.append(experience_tuple)
self.priorities.append(max_priority)
def update_priorities(self, indices, td_errors):
for idx, td_error in zip(indices, td_errors.tolist()):
# Update the transition's priority
self.priorities[idx] = ____
def increase_beta(self):
# Increase beta if less than 1
self.beta = ____
def __len__(self):
return len(self.memory)
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)