Échantillonnage à partir du tampon PER
Avant de pouvoir utiliser la classe PrioritizedExperienceBuffer pour former votre agent, vous devez implémenter la méthode .sample(). Cette méthode prend comme argument la taille de l'échantillon que vous souhaitez extraire et renvoie les transitions échantillonnées sous forme d'tensors, ainsi que leurs indices dans la mémoire tampon et leur poids d'importance.
Un tampon d'une capacité de 10 a été préchargé dans votre environnement afin que vous puissiez effectuer des échantillonnages.
Cet exercice fait partie du cours
Apprentissage par renforcement profond en Python
Instructions
- Calculez la probabilité d'échantillonnage associée à chaque transition.
- Dessinez les indices correspondant aux transitions dans l'échantillon ;
np.random.choice(a, s, p=p)prend un échantillon de taillesavec remplacement à partir du tableaua, sur la base du tableau de probabilitésp. - Calculez le poids d'importance associé à chaque transition.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
def sample(self, batch_size):
priorities = np.array(self.priorities)
# Calculate the sampling probabilities
probabilities = ____ / np.sum(____)
# Draw the indices for the sample
indices = np.random.choice(____)
# Calculate the importance weights
weights = (1 / (len(self.memory) * ____)) ** ____
weights /= np.max(weights)
states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
weights = [weights[idx] for idx in indices]
states_tensor = torch.tensor(states, dtype=torch.float32)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
dones_tensor = torch.tensor(dones, dtype=torch.float32)
weights_tensor = torch.tensor(weights, dtype=torch.float32)
actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
dones_tensor, indices, weights_tensor)
PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))