Buffer de reprodução de experiência
Agora você criará a estrutura de dados para dar suporte ao Experience Replay, o que permitirá que seu agente aprenda com muito mais eficiência.
Esse buffer de reprodução deve suportar duas operações:
- Armazenamento de experiências em sua memória para amostragem futura.
- "Repetição" de um lote de experiências passadas, coletadas aleatoriamente de sua memória.
Como os dados amostrados no buffer de reprodução serão usados para alimentar uma rede neural, o buffer deve retornar torch
Tensors por conveniência.
Os módulos torch
e random
e a classe deque
foram importados para seu ambiente de exercícios.
Este exercício faz parte do curso
Aprendizado por reforço profundo em Python
Instruções do exercício
- Conclua o método
push()
deReplayBuffer
anexandoexperience_tuple
à memória intermediária. - No método
sample()
, extraia uma amostra aleatória de tamanhobatch_size
deself.memory
. - Novamente em
sample()
, a amostra é inicialmente desenhada como uma lista de tuplas; certifique-se de que ela seja transformada em uma tupla de listas. - Transforme
actions_tensor
na forma(batch_size, 1)
em vez de(batch_size)
.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
class ReplayBuffer:
def __init__(self, capacity):
self.memory = deque([], maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Append experience_tuple to the memory buffer
self.memory.____
def __len__(self):
return len(self.memory)
def sample(self, batch_size):
# Draw a random sample of size batch_size
batch = ____(____, ____)
# Transform batch into a tuple of lists
states, actions, rewards, next_states, dones = ____
states_tensor = torch.tensor(states, dtype=torch.float32)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
dones_tensor = torch.tensor(dones, dtype=torch.float32)
# Ensure actions_tensor has shape (batch_size, 1)
actions_tensor = torch.tensor(actions, dtype=torch.long).____
return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor