ComenzarEmpieza gratis

DQN con repetición priorizada de experiencias

En este ejercicio, introducirás la Repetición de Experiencias Priorizadas (PER) para mejorar el algoritmo DQN. PER pretende optimizar el lote de transiciones seleccionadas para actualizar la red en cada paso.

Como referencia, los nombres de los métodos que has declarado para PrioritizedReplayBuffer son:

  • push() (para empujar las transiciones a la memoria intermedia)
  • sample() (para muestrear un lote de transiciones de la memoria intermedia)
  • increase_beta() (para aumentar el muestreo de importancia)
  • update_priorities() (para actualizar las prioridades muestreadas)

La función describe_episode() se utiliza de nuevo para describir cada episodio.

Este ejercicio forma parte del curso

Aprendizaje profundo por refuerzo en Python

Ver curso

Instrucciones de ejercicio

  • Instanciar un búfer de Reproducción de Experiencias Priorizadas con una capacidad de 10000 transiciones.
  • Aumenta la influencia del muestreo de importancia a lo largo del tiempo actualizando el parámetro beta.
  • Actualiza la prioridad de las experiencias muestreadas en función de su último error TD.

Ejercicio interactivo práctico

Pruebe este ejercicio completando este código de muestra.

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False   
    step = 0
    episode_reward = 0    
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)            
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
Editar y ejecutar código