DQN con repetición priorizada de experiencias
En este ejercicio, introducirás la Repetición de Experiencias Priorizadas (PER) para mejorar el algoritmo DQN. PER pretende optimizar el lote de transiciones seleccionadas para actualizar la red en cada paso.
Como referencia, los nombres de los métodos que has declarado para PrioritizedReplayBuffer
son:
push()
(para empujar las transiciones a la memoria intermedia)sample()
(para muestrear un lote de transiciones de la memoria intermedia)increase_beta()
(para aumentar el muestreo de importancia)update_priorities()
(para actualizar las prioridades muestreadas)
La función describe_episode()
se utiliza de nuevo para describir cada episodio.
Este ejercicio forma parte del curso
Aprendizaje profundo por refuerzo en Python
Instrucciones de ejercicio
- Instanciar un búfer de Reproducción de Experiencias Priorizadas con una capacidad de 10000 transiciones.
- Aumenta la influencia del muestreo de importancia a lo largo del tiempo actualizando el parámetro
beta
. - Actualiza la prioridad de las experiencias muestreadas en función de su último error TD.
Ejercicio interactivo práctico
Pruebe este ejercicio completando este código de muestra.
# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)
for episode in range(5):
state, info = env.reset()
done = False
step = 0
episode_reward = 0
# Increase the replay buffer's beta parameter
replay_buffer.____
while not done:
step += 1
total_steps += 1
q_values = online_network(state)
action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.push(state, action, reward, next_state, done)
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
next_q_values = target_network(next_states).amax(1)
target_q_values = rewards + gamma * next_q_values * (1-dones)
td_errors = target_q_values - q_values
# Update the replay buffer priorities for that batch
replay_buffer.____(____, ____)
loss = torch.sum(weights * (q_values - target_q_values) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_target_network(target_network, online_network, tau=.005)
state = next_state
episode_reward += reward
describe_episode(episode, reward, episode_reward, step)