ComeçarComece gratuitamente

Amostragem do buffer PER

Antes de poder usar a classe Prioritized Experience Buffer para treinar seu agente, você ainda precisa implementar o método .sample(). Esse método recebe como argumento o tamanho da amostra que você deseja desenhar e retorna as transições amostradas como tensors, juntamente com seus índices no buffer de memória e seu peso de importância.

Um buffer com capacidade 10 foi pré-carregado no seu ambiente para que você faça a amostragem.

Este exercício faz parte do curso

Aprendizado por reforço profundo em Python

Ver Curso

Instruções de exercício

  • Calcule a probabilidade de amostragem associada a cada transição.
  • Desenhe os índices correspondentes às transições na amostra; np.random.choice(a, s, p=p) obtém uma amostra de tamanho s com substituição da matriz a, com base na matriz de probabilidade p.
  • Calcule o peso de importância associado a cada transição.

Exercício interativo prático

Experimente este exercício preenchendo este código de exemplo.

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Calculate the sampling probabilities
    probabilities = ____ / np.sum(____)
    # Draw the indices for the sample
    indices = np.random.choice(____)
    # Calculate the importance weights
    weights = (1 / (len(self.memory) * ____)) ** ____
    weights /= np.max(weights)
    states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
    weights = [weights[idx] for idx in indices]
    states_tensor = torch.tensor(states, dtype=torch.float32)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
    dones_tensor = torch.tensor(dones, dtype=torch.float32)
    weights_tensor = torch.tensor(weights, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
            dones_tensor, indices, weights_tensor)

PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))
Editar e executar código