Aan de slagBegin gratis

Steekproeven uit de PER-buffer

Voordat je de Prioritized Experience Buffer-klasse kunt gebruiken om je agent te trainen, moet je nog de methode .sample() implementeren. Deze methode krijgt als argument de grootte van de steekproef die je wilt trekken en retourneert de getrokken transities als tensors, samen met hun indexen in de geheugenbuffer en hun belangrijkheidsgewicht.

Een buffer met capaciteit 10 is alvast in je omgeving geladen, zodat je daaruit kunt steekproeven.

Deze oefening maakt deel uit van de cursus

Deep Reinforcement Learning in Python

Bekijk cursus

Oefeninstructies

  • Bereken de steekproefkans die hoort bij elke transitie.
  • Trek de indexen die overeenkomen met de transities in de steekproef; np.random.choice(a, s, p=p) neemt een steekproef van grootte s met terugleggen uit de array a, op basis van kansarray p.
  • Bereken het belangrijkheidsgewicht dat hoort bij elke transitie.

Interactieve oefening met praktijkervaring

Probeer deze oefening door deze voorbeeldcode aan te vullen.

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Calculate the sampling probabilities
    probabilities = ____ / np.sum(____)
    # Draw the indices for the sample
    indices = np.random.choice(____)
    # Calculate the importance weights
    weights = (1 / (len(self.memory) * ____)) ** ____
    weights /= np.max(weights)
    states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
    weights = [weights[idx] for idx in indices]
    states_tensor = torch.tensor(states, dtype=torch.float32)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
    dones_tensor = torch.tensor(dones, dtype=torch.float32)
    weights_tensor = torch.tensor(weights, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
            dones_tensor, indices, weights_tensor)

PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))
Code bewerken en uitvoeren