Experience replay buffer
You will now create the data structure to support Experience Replay, which will enable your agent to learn much more efficiently.
This replay buffer should support two operations:
- Storing experiences in its memory for future sampling.
- "Replaying" a randomly sampled batch of past experiences from its memory.
As the data sampled from the replay buffer will be used to feed into a neural network, the buffer should return torch Tensors for convenience.
The torch and random modules and the deque class have been imported into your exercise environment.
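As a rough illustration of the two operations described above, the sketch below uses only the standard library. It assumes a toy capacity of 3 and stores plain integers in place of real experience tuples; the names memory, contents and batch are hypothetical and chosen for this example only.

```python
import random
from collections import deque

# Toy buffer (assumption: capacity of 3, integers instead of experiences)
memory = deque([], maxlen=3)

# Storing: once the deque is full, each append drops the oldest item
for step in range(5):
    memory.append(step)

contents = list(memory)  # only the three most recent items remain

# Replaying: draw 2 distinct items at random, without replacement
batch = random.sample(memory, 2)
```

Because of maxlen, the buffer never grows past its capacity, so old experiences are discarded automatically as new ones arrive.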
This exercise is part of the course
Deep Reinforcement Learning in Python
Instructions
- Complete the push() method of ReplayBuffer by appending experience_tuple to the buffer memory.
- In the sample() method, draw a random sample of size batch_size from self.memory.
- Again in sample(), the sample is initially drawn as a list of tuples; ensure that it is transformed into a tuple of lists.
- Transform actions_tensor into shape (batch_size, 1) instead of (batch_size).
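The step from a list of tuples to a tuple of lists can be pictured with plain Python, independent of the exercise code. This is a minimal sketch on a made-up batch of three experiences; the values are invented for illustration, and the nested-list column at the end is only a pure-Python analogy for the (batch_size, 1) shape, not the tensor operation itself.

```python
# Hypothetical batch: three (state, action, reward, next_state, done) tuples
batch = [
    ([0.0, 0.1], 1, 1.0, [0.1, 0.2], False),
    ([0.1, 0.2], 0, 0.0, [0.2, 0.3], False),
    ([0.2, 0.3], 1, -1.0, [0.3, 0.4], True),
]

# zip(*batch) groups the i-th field of every tuple together;
# each group is converted to a list, giving a tuple of five lists
states, actions, rewards, next_states, dones = tuple(
    list(field) for field in zip(*batch)
)

# Analogy for shape (batch_size, 1): one single-element row per sample
actions_column = [[a] for a in actions]
```

After this step, each variable holds one field for the whole batch, which is the layout a tensor constructor expects.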
Hands-on interactive exercise
Try this exercise by completing this sample code.
class ReplayBuffer:
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Append experience_tuple to the memory buffer
        self.memory.____

    def __len__(self):
        return len(self.memory)

    def sample(self, batch_size):
        # Draw a random sample of size batch_size
        batch = ____(____, ____)
        # Transform batch into a tuple of lists
        states, actions, rewards, next_states, dones = ____
        states_tensor = torch.tensor(states, dtype=torch.float32)
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
        next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
        dones_tensor = torch.tensor(dones, dtype=torch.float32)
        # Ensure actions_tensor has shape (batch_size, 1)
        actions_tensor = torch.tensor(actions, dtype=torch.long).____
        return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor