Experience replay-buffer
Je gaat nu de datastructuur bouwen voor Experience Replay, zodat je agent veel efficiënter kan leren.
Deze replay-buffer moet twee bewerkingen ondersteunen:
- Ervaringen opslaan in het geheugen voor later samplen.
- Een willekeurig geselecteerde batch van eerdere ervaringen uit het geheugen "replayen".
Omdat de data uit de replay-buffer aan een neuraal netwerk wordt gevoerd, moet de buffer voor het gemak torch-Tensors teruggeven.
De modules torch en random en de klasse deque zijn al in je oefenomgeving geïmporteerd.
Deze oefening maakt deel uit van de cursus
Deep Reinforcement Learning in Python
Oefeninstructies
- Maak de
push()-methode vanReplayBufferaf doorexperience_tupletoe te voegen aan het buffermemory. - Trek in de
sample()-methode een willekeurige sample van groottebatch_sizeuitself.memory. - Zorg er in
sample()ook voor dat de sample, die eerst als een lijst met tuples wordt getrokken, wordt omgezet naar een tuple met lijsten. - Zet
actions_tensorom naar vorm(batch_size, 1)in plaats van(batch_size).
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
class ReplayBuffer:
def __init__(self, capacity):
self.memory = deque([], maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Append experience_tuple to the memory buffer
self.memory.____
def __len__(self):
return len(self.memory)
def sample(self, batch_size):
# Draw a random sample of size batch_size
batch = ____(____, ____)
# Transform batch into a tuple of lists
states, actions, rewards, next_states, dones = ____
states_tensor = torch.tensor(states, dtype=torch.float32)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
dones_tensor = torch.tensor(dones, dtype=torch.float32)
# Ensure actions_tensor has shape (batch_size, 1)
actions_tensor = torch.tensor(actions, dtype=torch.long).____
return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor