Experience replay buffer
You will now create the data structure that supports Experience Replay. It lets your agent learn more efficiently by reusing past experiences and by training on randomly sampled, less correlated batches.
This replay buffer should support two operations:
- Storing experiences in its memory for future sampling.
- "Replaying" a randomly sampled batch of past experiences from its memory.
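Here is a minimal sketch of these two operations in plain Python. It assumes a hypothetical capacity of 3 and uses integers as placeholders for real experience tuples. A deque created with maxlen silently drops its oldest entry once it is full, and random.sample draws without replacement.

```python
import random
from collections import deque

# Fixed-capacity memory: once full, each append evicts the oldest item
memory = deque([], maxlen=3)
for experience in range(5):  # integers stand in for experience tuples
    memory.append(experience)

print(list(memory))  # [2, 3, 4]: experiences 0 and 1 were evicted

# "Replay": draw 2 distinct stored experiences uniformly at random
batch = random.sample(memory, 2)
print(batch)
```

The bounded deque keeps memory usage constant and favours recent experience, while uniform sampling mixes old and new transitions in each batch.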
Because the data sampled from the replay buffer will be fed into a neural network, the buffer should return torch Tensors for convenience.
The torch and random modules and the deque class have been imported into your exercise environment.
This exercise is part of the course Deep Reinforcement Learning in Python.
Instructions
- Complete the push() method of ReplayBuffer by appending experience_tuple to the buffer memory.
- In the sample() method, draw a random sample of size batch_size from self.memory.
- Again in sample(), the sample is initially drawn as a list of tuples; transform it into a tuple of lists.
- Reshape actions_tensor to shape (batch_size, 1) instead of (batch_size,).
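The last two steps can be illustrated on a toy batch. The sketch assumes two made-up experiences with 2-dimensional states. zip(*batch) transposes a list of 5-tuples into five tuples, one per field, and unsqueeze(1) adds a trailing dimension of size 1.

```python
import torch

# Toy batch (made-up values): a list of (state, action, reward, next_state, done) tuples
batch = [
    ([0.0, 1.0], 1, 0.5, [1.0, 1.0], False),
    ([1.0, 1.0], 0, 1.0, [2.0, 1.0], True),
]

# Transpose: list of tuples -> one tuple per field
states, actions, rewards, next_states, dones = zip(*batch)
print(actions)  # (1, 0)

actions_tensor = torch.tensor(actions, dtype=torch.long)
print(actions_tensor.shape)  # torch.Size([2])

# Add a trailing dimension so each row holds a single action index
actions_tensor = actions_tensor.unsqueeze(1)
print(actions_tensor.shape)  # torch.Size([2, 1])
```

Calling view(-1, 1) or reshape(-1, 1) would give the same shape. unsqueeze(1) simply makes the intent of adding one axis explicit.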
import random
from collections import deque

import torch


class ReplayBuffer:
    def __init__(self, capacity):
        # Fixed-capacity memory: the oldest experiences are dropped once full
        self.memory = deque([], maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Append experience_tuple to the memory buffer
        self.memory.append(experience_tuple)

    def __len__(self):
        return len(self.memory)

    def sample(self, batch_size):
        # Draw a random sample of size batch_size (without replacement)
        batch = random.sample(self.memory, batch_size)
        # Transform batch into a tuple of sequences, one per field
        states, actions, rewards, next_states, dones = zip(*batch)
        states_tensor = torch.tensor(states, dtype=torch.float32)
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
        next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
        dones_tensor = torch.tensor(dones, dtype=torch.float32)
        # Ensure actions_tensor has shape (batch_size, 1)
        actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
        return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor
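This example assumes the buffer feeds a DQN-style Q-network, which is one common reason for the (batch_size, 1) action shape. Tensor.gather along dimension 1 needs an index tensor with the same number of dimensions as the Q-values. It then picks out, for each experience, the Q-value of the action actually taken. In this sketch the Q-values are random stand-ins for a network's output. The batch size, number of actions, sample values and discount factor gamma are all hypothetical.

```python
import torch

batch_size, n_actions = 4, 3  # hypothetical sizes
gamma = 0.99  # hypothetical discount factor

# Stand-ins for what sample() returns and what a Q-network would output
q_values = torch.rand(batch_size, n_actions)       # Q(s, .) for each state
next_q_values = torch.rand(batch_size, n_actions)  # Q(s', .) for each next state
actions_tensor = torch.tensor([0, 2, 1, 2], dtype=torch.long).unsqueeze(1)
rewards_tensor = torch.tensor([1.0, 0.0, 0.5, 1.0])
dones_tensor = torch.tensor([0.0, 0.0, 1.0, 0.0])

# Select Q(s, a) for the action taken in each experience, then drop the extra axis
q_taken = q_values.gather(1, actions_tensor).squeeze(1)

# TD target: (1 - done) zeroes the bootstrap term for terminal transitions
target = rewards_tensor + gamma * next_q_values.max(dim=1).values * (1 - dones_tensor)

print(q_taken.shape, target.shape)  # torch.Size([4]) torch.Size([4])
```

Storing dones as float32 lets the (1 - dones_tensor) mask be multiplied directly into the target without further casting.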