Probenahme aus dem Puffer PER
Bevor du die Klasse Prioritized Experience Buffer verwenden kannst, um deinen Agenten zu trainieren, musst du noch die Methode .sample()
implementieren. Diese Methode nimmt als Argument die Größe der Stichprobe, die du ziehen willst, und gibt die gesampelten Übergänge als tensors
zurück, zusammen mit ihren Indizes im Speicherpuffer und ihrem Wichtigkeitsgewicht.
In deiner Umgebung wurde ein Puffer mit einer Kapazität von 10 geladen, aus dem du Proben nehmen kannst.
Diese Übung ist Teil des Kurses
Deep Reinforcement Learning in Python
Anleitung zur Übung
- Berechne die Stichprobenwahrscheinlichkeit, die mit jedem Übergang verbunden ist.
- Ziehe die Indizes, die den Übergängen in der Stichprobe entsprechen;
np.random.choice(a, s, p=p)
zieht eine Stichprobe der Größes
mit Ersatz aus dem Arraya
, basierend auf dem Wahrscheinlichkeitsarrayp
. - Berechne die Wichtigkeit, die jedem Übergang zugeordnet ist.
Interaktive Übung zum Anfassen
Probieren Sie diese Übung aus, indem Sie diesen Beispielcode ausführen.
def sample(self, batch_size):
priorities = np.array(self.priorities)
# Calculate the sampling probabilities
probabilities = ____ / np.sum(____)
# Draw the indices for the sample
indices = np.random.choice(____)
# Calculate the importance weights
weights = (1 / (len(self.memory) * ____)) ** ____
weights /= np.max(weights)
states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
weights = [weights[idx] for idx in indices]
states_tensor = torch.tensor(states, dtype=torch.float32)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
dones_tensor = torch.tensor(dones, dtype=torch.float32)
weights_tensor = torch.tensor(weights, dtype=torch.float32)
actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
dones_tensor, indices, weights_tensor)
PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))