Campionare dal buffer PER
Prima di poter usare la classe Prioritized Experience Buffer per addestrare il tuo agente, devi ancora implementare il metodo .sample(). Questo metodo prende come argomento la dimensione del campione che vuoi estrarre e restituisce le transizioni campionate come tensors, insieme ai loro indici nel buffer di memoria e ai relativi pesi d'importanza.
Un buffer con capacità 10 è già stato caricato nel tuo ambiente per consentirti di effettuare i campionamenti.
Questo esercizio fa parte del corso
Deep Reinforcement Learning in Python
Istruzioni dell'esercizio
- Calcola la probabilità di campionamento associata a ciascuna transizione.
- Estrai gli indici corrispondenti alle transizioni nel campione;
np.random.choice(a, s, p=p)estrae un campione di dimensionescon reimmissione dall'arraya, basandosi sull'array di probabilitàp. - Calcola il peso d'importanza associato a ciascuna transizione.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
def sample(self, batch_size):
priorities = np.array(self.priorities)
# Calculate the sampling probabilities
probabilities = ____ / np.sum(____)
# Draw the indices for the sample
indices = np.random.choice(____)
# Calculate the importance weights
weights = (1 / (len(self.memory) * ____)) ** ____
weights /= np.max(weights)
states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
weights = [weights[idx] for idx in indices]
states_tensor = torch.tensor(states, dtype=torch.float32)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
dones_tensor = torch.tensor(dones, dtype=torch.float32)
weights_tensor = torch.tensor(weights, dtype=torch.float32)
actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
dones_tensor, indices, weights_tensor)
PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))