Aan de slagGa gratis aan de slag

Steekproeven uit de PER-buffer

Voordat je de Prioritized Experience Buffer-klasse kunt gebruiken om je agent te trainen, moet je nog de methode .sample() implementeren. Deze methode krijgt als argument de grootte van de steekproef die je wilt trekken en retourneert de getrokken transities als tensors, samen met hun indexen in de geheugenbuffer en hun belangrijkheidsgewicht.

Een buffer met capaciteit 10 is alvast in je omgeving geladen, zodat je daaruit kunt steekproeven.

Deze oefening maakt deel uit van de cursus

Deep Reinforcement Learning in Python

Cursus bekijken

Oefeninstructies

  • Bereken de steekproefkans die hoort bij elke transitie.
  • Trek de indexen die overeenkomen met de transities in de steekproef; np.random.choice(a, s, p=p) neemt een steekproef van grootte s met terugleggen uit de array a, op basis van kansarray p.
  • Bereken het belangrijkheidsgewicht dat hoort bij elke transitie.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Calculate the sampling probabilities
    probabilities = ____ / np.sum(____)
    # Draw the indices for the sample
    indices = np.random.choice(____)
    # Calculate the importance weights
    weights = (1 / (len(self.memory) * ____)) ** ____
    weights /= np.max(weights)
    states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
    weights = [weights[idx] for idx in indices]
    states_tensor = torch.tensor(states, dtype=torch.float32)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
    dones_tensor = torch.tensor(dones, dtype=torch.float32)
    weights_tensor = torch.tensor(weights, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
            dones_tensor, indices, weights_tensor)

PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))
Code bewerken en uitvoeren