
DQN with prioritized experience replay

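In this exercise, you will add Prioritized Experience Replay (PER) to the DQN algorithm. Instead of drawing transitions uniformly from the buffer, PER samples each transition with a probability that grows with its priority (here derived from its TD error), so that each network update uses a more informative batch.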

For reference, the method names you have declared for PrioritizedReplayBuffer are:

  • push() (to push transitions to the buffer)
  • sample() (to sample a batch of transitions from the buffer)
  • increase_beta() (to increase the strength of the importance sampling correction)
  • update_priorities() (to update the priorities of the sampled transitions)
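
For readers who want to see how these four methods might fit together, here is a minimal sketch of a proportional prioritized buffer. It is an illustration under assumptions, not the class declared earlier in the course: the constructor parameters (alpha, beta, beta_increment, epsilon) and their default values are hypothetical, and sample() returns plain Python tuples and lists rather than tensors so that the sketch runs without PyTorch.

```python
import random
from collections import deque


class PrioritizedReplayBuffer:
    """Illustrative proportional PER buffer (assumed design, not the course's class)."""

    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01):
        # Hypothetical hyperparameters: alpha shapes the sampling distribution,
        # beta scales the importance sampling correction, epsilon keeps priorities > 0.
        self.memory = deque(maxlen=capacity)
        self.priorities = deque(maxlen=capacity)
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.epsilon = epsilon

    def __len__(self):
        return len(self.memory)

    def push(self, state, action, reward, next_state, done):
        # New transitions receive the current maximum priority (1.0 if the buffer is empty).
        max_priority = max(self.priorities) if self.priorities else 1.0
        self.memory.append((state, action, reward, next_state, done))
        self.priorities.append(max_priority)

    def sample(self, batch_size):
        # Sampling probability is proportional to priority ** alpha.
        scaled = [p ** self.alpha for p in self.priorities]
        total = sum(scaled)
        probabilities = [s / total for s in scaled]
        indices = random.choices(range(len(self.memory)), weights=probabilities, k=batch_size)
        # Importance sampling weights, normalized by the largest weight in the batch.
        weights = [(len(self.memory) * probabilities[i]) ** (-self.beta) for i in indices]
        max_weight = max(weights)
        weights = [w / max_weight for w in weights]
        states, actions, rewards, next_states, dones = zip(*(self.memory[i] for i in indices))
        return states, actions, rewards, next_states, dones, indices, weights

    def update_priorities(self, indices, td_errors):
        # The new priority is the absolute TD error plus a small constant.
        for idx, td_error in zip(indices, td_errors):
            self.priorities[idx] = abs(float(td_error)) + self.epsilon

    def increase_beta(self):
        # Move beta toward 1.0, where the correction is at full strength.
        self.beta = min(1.0, self.beta + self.beta_increment)
```

Giving new transitions the current maximum priority is one common design choice: it makes fresh experience likely to be replayed before its TD error is known.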

The describe_episode() function is used again to describe each episode.

This exercise is part of the course

Deep Reinforcement Learning in Python


Exercise instructions

  • Instantiate a Prioritized Experience Replay buffer with a capacity of 10000 transitions.
  • Increase the influence of importance sampling over time by updating the beta parameter.
  • Update the priority of the sampled experiences based on their latest TD error.
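
To make the roles of the priority and of beta concrete, the following sketch computes sampling probabilities and importance sampling weights for three transitions. It assumes proportional prioritization, P(i) = p_i^alpha / sum_k p_k^alpha, with weights w_i = (N * P(i))^(-beta) divided by the largest weight. The helper names, the values of alpha and epsilon, and the TD errors are all made up for illustration.

```python
def sampling_probabilities(priorities, alpha=0.6):
    """Proportional prioritization: P(i) = p_i**alpha / sum_k p_k**alpha."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probabilities, beta):
    """Weights w_i = (N * P(i))**(-beta), divided by the largest weight."""
    n = len(probabilities)
    raw = [(n * p) ** (-beta) for p in probabilities]
    largest = max(raw)
    return [w / largest for w in raw]


# Illustrative TD errors; the priority is |TD error| + epsilon (assumed form).
td_errors = [0.1, -0.5, 2.0]
epsilon = 0.01
priorities = [abs(e) + epsilon for e in td_errors]
probabilities = sampling_probabilities(priorities)

# A larger beta shrinks the weight of the most frequently sampled transition.
for beta in (0.4, 0.7, 1.0):
    weights = importance_weights(probabilities, beta)
    print(f"beta={beta}: weights={[round(w, 3) for w in weights]}")
```

At beta = 0 every weight equals 1, so no correction is applied. As beta approaches 1, the transition that is sampled most often is down-weighted the most, which is why the exercise increases beta over time.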

Hands-on interactive exercise

Try this exercise by completing this sample code.

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False   
    step = 0
    episode_reward = 0    
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(batch_size)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)            
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
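
The target and loss lines in the loop above can be checked by hand on a tiny batch. The pure-Python sketch below mirrors them without PyTorch; the helper names td_targets and weighted_squared_loss are hypothetical, and the numbers are invented for illustration.

```python
def td_targets(rewards, next_q_max, dones, gamma):
    """Target r + gamma * max_a Q_target(s', a), with the bootstrap term zeroed when done."""
    return [r + gamma * q * (1 - d) for r, q, d in zip(rewards, next_q_max, dones)]


def weighted_squared_loss(q_values, targets, weights):
    """Sum of squared TD errors, each scaled by its importance sampling weight."""
    return sum(w * (q - t) ** 2 for q, t, w in zip(q_values, targets, weights))


# Two transitions; the second one ends its episode (done = 1).
rewards = [1.0, 0.0]
next_q_max = [2.0, 5.0]
dones = [0, 1]
q_values = [1.0, 1.0]
weights = [1.0, 0.5]

targets = td_targets(rewards, next_q_max, dones, gamma=0.5)
td_errors = [t - q for t, q in zip(targets, q_values)]
loss = weighted_squared_loss(q_values, targets, weights)
print(targets, td_errors, loss)
```

The second transition contributes only half of its squared error to the loss because its weight is 0.5, while its TD error is still what would be passed on to update its priority.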