
Sampling from the PER buffer

Before you can use the prioritized experience replay buffer class to train your agent, you still need to implement its .sample() method. This method takes as its argument the size of the sample you want to draw and returns the sampled transitions as tensors, along with their indices in the memory buffer and their importance weights.

A buffer with capacity 10 has been pre-loaded in your environment for you to sample from.
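For reference, a minimal sketch of the class structure that .sample() relies on might look like the following. The constructor and push() shown here are assumptions, not the course's implementation; only the attribute names self.memory and self.priorities are taken from the exercise code below, and self.alpha / self.beta are the usual PER hyperparameters.

from collections import deque

class PrioritizedReplayBuffer:
    """Sketch of a PER buffer; attribute names match the sample() code below."""
    def __init__(self, capacity, alpha=0.6, beta=0.4):
        self.memory = deque(maxlen=capacity)      # (state, action, reward, next_state, done) tuples
        self.priorities = deque(maxlen=capacity)  # one priority per stored transition
        self.alpha = alpha                        # how strongly priorities skew sampling
        self.beta = beta                          # how strongly importance weights correct that skew

    def push(self, transition):
        # Give new transitions the current maximum priority so they are sampled at least once
        max_priority = max(self.priorities, default=1.0)
        self.memory.append(transition)
        self.priorities.append(max_priority)

A buffer built this way with capacity=10 and filled with a few transitions would play the role of the pre-loaded buffer object used at the end of the exercise.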

This exercise is part of the course Deep Reinforcement Learning in Python.

Exercise instructions

  • Calculate the sampling probability associated with each transition.
  • Draw the indices corresponding to the transitions in the sample; np.random.choice(a, s, p=p) draws a sample of size s with replacement from a, using the probability array p.
  • Calculate the importance weight associated with each transition; a numeric sketch of these formulas follows this list.

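As a reminder, in prioritized experience replay a transition i with priority p_i is sampled with probability P(i) = p_i^alpha / sum_k p_k^alpha, and its importance weight is w_i = (1 / (N * P(i)))^beta, usually normalized by the largest weight. The exponent names alpha and beta follow the standard formulation and are an assumption here; the exercise only shows them as blanks. A small standalone sketch on a toy priority array (the values are made up for illustration):

import numpy as np

alpha, beta = 0.6, 0.4                       # illustrative PER hyperparameters
priorities = np.array([1.0, 0.5, 2.0, 0.1])  # toy priorities, one per stored transition

scaled = priorities ** alpha
probabilities = scaled / np.sum(scaled)      # P(i) = p_i^alpha / sum_k p_k^alpha
weights = (1 / (len(priorities) * probabilities)) ** beta
weights /= np.max(weights)                   # normalize so the largest weight is 1

print(probabilities)  # sums to 1; higher-priority transitions are more likely
print(weights)        # rarely sampled transitions get larger corrective weights
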
Hands-on interactive exercise

Complete this sample code to finish the exercise.

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Calculate the sampling probabilities
    probabilities = ____ / np.sum(____)
    # Draw the indices for the sample
    indices = np.random.choice(____)
    # Calculate the importance weights
    weights = (1 / (len(self.memory) * ____)) ** ____
    weights /= np.max(weights)
    states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
    weights = [weights[idx] for idx in indices]
    states_tensor = torch.tensor(states, dtype=torch.float32)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
    dones_tensor = torch.tensor(dones, dtype=torch.float32)
    weights_tensor = torch.tensor(weights, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
            dones_tensor, indices, weights_tensor)

PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))
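One possible way to fill in the three blanks, shown here as a sketch rather than the course's official solution: it assumes the buffer keeps one priority per transition in self.priorities and defines the PER exponents self.alpha and self.beta (neither attribute is shown in the exercise snippet). Combined with the class sketch given earlier, the completed method runs end to end.

    # Sampling probabilities: P(i) = p_i**alpha / sum_k p_k**alpha
    probabilities = priorities ** self.alpha / np.sum(priorities ** self.alpha)
    # Indices drawn with replacement, weighted by the probabilities
    indices = np.random.choice(len(self.memory), batch_size, p=probabilities)
    # Importance weights: w_i = (1 / (N * P(i)))**beta, normalized by the max just below
    weights = (1 / (len(self.memory) * probabilities)) ** self.beta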