Buffer de reprodução de experiência
Agora você criará a estrutura de dados para dar suporte ao Experience Replay, o que permitirá que seu agente aprenda com muito mais eficiência.
Esse buffer de reprodução deve suportar duas operações:
- Armazenamento de experiências em sua memória para amostragem futura.
- "Repetição" de um lote de experiências passadas, coletadas aleatoriamente de sua memória.
Como os dados amostrados no buffer de reprodução serão usados para alimentar uma rede neural, o buffer deve retornar torch Tensors por conveniência.
Os módulos torch e random e a classe deque foram importados para seu ambiente de exercícios.
Este exercício faz parte do curso
Aprendizado por reforço profundo em Python
Instruções do exercício
- Conclua o método
push()deReplayBufferanexandoexperience_tupleà memória intermediária. - No método
sample(), extraia uma amostra aleatória de tamanhobatch_sizedeself.memory. - Novamente em
sample(), a amostra é inicialmente desenhada como uma lista de tuplas; certifique-se de que ela seja transformada em uma tupla de listas. - Transforme
actions_tensorna forma(batch_size, 1)em vez de(batch_size).
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
class ReplayBuffer:
def __init__(self, capacity):
self.memory = deque([], maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Append experience_tuple to the memory buffer
self.memory.____
def __len__(self):
return len(self.memory)
def sample(self, batch_size):
# Draw a random sample of size batch_size
batch = ____(____, ____)
# Transform batch into a tuple of lists
states, actions, rewards, next_states, dones = ____
states_tensor = torch.tensor(states, dtype=torch.float32)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
dones_tensor = torch.tensor(dones, dtype=torch.float32)
# Ensure actions_tensor has shape (batch_size, 1)
actions_tensor = torch.tensor(actions, dtype=torch.long).____
return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor