CommencerCommencer gratuitement

Échantillonnage à partir du tampon PER

Avant de pouvoir utiliser la classe PrioritizedExperienceBuffer pour former votre agent, vous devez implémenter la méthode .sample(). Cette méthode prend comme argument la taille de l'échantillon que vous souhaitez extraire et renvoie les transitions échantillonnées sous forme d'tensors, ainsi que leurs indices dans la mémoire tampon et leur poids d'importance.

Un tampon d'une capacité de 10 a été préchargé dans votre environnement afin que vous puissiez effectuer des échantillonnages.

Cet exercice fait partie du cours

Apprentissage par renforcement profond en Python

Afficher le cours

Instructions

  • Calculez la probabilité d'échantillonnage associée à chaque transition.
  • Dessinez les indices correspondant aux transitions dans l'échantillon ; np.random.choice(a, s, p=p) prend un échantillon de taille s avec remplacement à partir du tableau a, sur la base du tableau de probabilités p.
  • Calculez le poids d'importance associé à chaque transition.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Calculate the sampling probabilities
    probabilities = ____ / np.sum(____)
    # Draw the indices for the sample
    indices = np.random.choice(____)
    # Calculate the importance weights
    weights = (1 / (len(self.memory) * ____)) ** ____
    weights /= np.max(weights)
    states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
    weights = [weights[idx] for idx in indices]
    states_tensor = torch.tensor(states, dtype=torch.float32)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
    dones_tensor = torch.tensor(dones, dtype=torch.float32)
    weights_tensor = torch.tensor(weights, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
            dones_tensor, indices, weights_tensor)

PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))
Modifier et exécuter le code