Experience replay buffer
Ora creerai la struttura dati che supporta l’Experience Replay, che permetterà al tuo agente di imparare in modo molto più efficiente.
Questo replay buffer deve supportare due operazioni:
- Memorizzare le esperienze nella propria memoria per campionamenti futuri.
- "Riprodurre" (replay) un batch campionato casualmente di esperienze passate dalla propria memoria.
Poiché i dati campionati dal replay buffer verranno usati come input per una rete neurale, il buffer dovrebbe restituire Tensors di torch per comodità.
I moduli torch e random e la classe deque sono stati importati nel tuo ambiente di esercizio.
Questo esercizio fa parte del corso
Deep Reinforcement Learning in Python
Istruzioni dell'esercizio
- Completa il metodo
push()diReplayBufferaggiungendoexperience_tuplealla memoria del buffer. - Nel metodo
sample(), estrai un campione casuale di dimensionebatch_sizedaself.memory. - Sempre in
sample(), il campione viene inizialmente estratto come lista di tuple; assicurati che venga trasformato in una tupla di liste. - Trasforma
actions_tensornella forma(batch_size, 1)invece di(batch_size).
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
class ReplayBuffer:
def __init__(self, capacity):
self.memory = deque([], maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Append experience_tuple to the memory buffer
self.memory.____
def __len__(self):
return len(self.memory)
def sample(self, batch_size):
# Draw a random sample of size batch_size
batch = ____(____, ____)
# Transform batch into a tuple of lists
states, actions, rewards, next_states, dones = ____
states_tensor = torch.tensor(states, dtype=torch.float32)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
dones_tensor = torch.tensor(dones, dtype=torch.float32)
# Ensure actions_tensor has shape (batch_size, 1)
actions_tensor = torch.tensor(actions, dtype=torch.long).____
return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor