Búfer de repetición de experiencia
Ahora crearás la estructura de datos para soportar la Reproducción de Experiencias, lo que permitirá a tu agente aprender de forma mucho más eficaz.
Este búfer de repetición debe admitir dos operaciones:
- Almacena experiencias en su memoria para muestrearlas en el futuro.
- "Reproduciendo" de su memoria un lote aleatorio de experiencias pasadas.
Como los datos muestreados del búfer de repetición se utilizarán para alimentar una red neuronal, el búfer debe devolver torch
Tensores por comodidad.
Los módulos torch
y random
y la clase deque
se han importado a tu entorno de ejercicio.
Este ejercicio forma parte del curso
Aprendizaje profundo por refuerzo en Python
Instrucciones del ejercicio
- Completa el método
push()
deReplayBuffer
añadiendoexperience_tuple
a la memoria intermedia. - En el método
sample()
, extrae una muestra aleatoria de tamañobatch_size
deself.memory
. - De nuevo en
sample()
, la muestra se extrae inicialmente como una lista de tuplas; asegúrate de que se transforma en una tupla de listas. - Transforma
actions_tensor
en la forma(batch_size, 1)
en lugar de(batch_size)
.
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
class ReplayBuffer:
def __init__(self, capacity):
self.memory = deque([], maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Append experience_tuple to the memory buffer
self.memory.____
def __len__(self):
return len(self.memory)
def sample(self, batch_size):
# Draw a random sample of size batch_size
batch = ____(____, ____)
# Transform batch into a tuple of lists
states, actions, rewards, next_states, dones = ____
states_tensor = torch.tensor(states, dtype=torch.float32)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
dones_tensor = torch.tensor(dones, dtype=torch.float32)
# Ensure actions_tensor has shape (batch_size, 1)
actions_tensor = torch.tensor(actions, dtype=torch.long).____
return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor