Experience replay buffer
You will now create the data structure to support Experience Replay, which will enable your agent to learn much more efficiently.
This replay buffer should support two operations:
- Storing experiences in its memory for future sampling.
- "Replaying" a randomly sampled batch of past experiences from its memory.
Because the data sampled from the replay buffer will be fed into a neural network, the buffer should return torch Tensors for convenience.
The `torch` and `random` modules and the `deque` class have been imported into your exercise environment.
This exercise is part of the course Deep Reinforcement Learning in Python.
Exercise instructions
- Complete the `push()` method of `ReplayBuffer` by appending `experience_tuple` to the buffer memory.
- In the `sample()` method, draw a random sample of size `batch_size` from `self.memory`.
- Again in `sample()`, the sample is initially drawn as a list of tuples; ensure that it is transformed into a tuple of lists.
- Transform `actions_tensor` into shape `(batch_size, 1)` instead of `(batch_size,)`.
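The "list of tuples to tuple of lists" step in the third instruction is commonly handled by unpacking the batch into `zip`; strictly speaking, `zip(*batch)` yields a tuple of tuples, which `torch.tensor` accepts just as well. A minimal sketch with made-up experience tuples (the values below are placeholders, not data from the exercise):

```python
import random

# Three toy experiences of the form (state, action, reward)
memory = [("s0", 0, 1.0), ("s1", 1, 0.5), ("s2", 0, 2.0)]

batch = random.sample(memory, 2)        # a list of 2 experience tuples
states, actions, rewards = zip(*batch)  # regrouped by field: one tuple per column
```

After the regrouping, `states` holds every sampled state, `actions` every sampled action, and so on, which is the layout `torch.tensor` expects.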
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
```python
class ReplayBuffer:
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Append experience_tuple to the memory buffer
        self.memory.____

    def __len__(self):
        return len(self.memory)

    def sample(self, batch_size):
        # Draw a random sample of size batch_size
        batch = ____(____, ____)
        # Transform batch into a tuple of lists
        states, actions, rewards, next_states, dones = ____
        states_tensor = torch.tensor(states, dtype=torch.float32)
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
        next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
        dones_tensor = torch.tensor(dones, dtype=torch.float32)
        # Ensure actions_tensor has shape (batch_size, 1)
        actions_tensor = torch.tensor(actions, dtype=torch.long).____
        return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor
```
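For reference, one way the completed buffer might look. This is a sketch, not necessarily the course's official solution: it assumes `random.sample` for uniform sampling without replacement, `zip(*batch)` for the regrouping, and `unsqueeze(1)` as one way to obtain the `(batch_size, 1)` action shape.

```python
import random
from collections import deque

import torch


class ReplayBuffer:
    """Fixed-capacity buffer of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        # Oldest experiences are evicted automatically once capacity is reached
        self.memory.append((state, action, reward, next_state, done))

    def __len__(self):
        return len(self.memory)

    def sample(self, batch_size):
        # Draw batch_size experiences uniformly at random, without replacement
        batch = random.sample(self.memory, batch_size)
        # Regroup by field: all states together, all actions together, ...
        states, actions, rewards, next_states, dones = zip(*batch)
        states_tensor = torch.tensor(states, dtype=torch.float32)
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
        next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
        dones_tensor = torch.tensor(dones, dtype=torch.float32)
        # unsqueeze(1) turns shape (batch_size,) into (batch_size, 1), the
        # shape typically expected when gathering Q-values per action
        actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
        return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor
```

Because `deque` is created with `maxlen=capacity`, pushing beyond capacity silently drops the oldest experience, so the buffer always holds the most recent `capacity` transitions.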