Aan de slagGa gratis aan de slag

DQN met prioritaire experience replay

In deze oefening voeg je Prioritized Experience Replay (PER) toe om het DQN-algoritme te verbeteren. PER probeert de batch transities die bij elke stap wordt geselecteerd om het netwerk te updaten, te optimaliseren.

Ter referentie, de methoden die je hebt gedefinieerd voor PrioritizedReplayBuffer zijn:

  • push() (om transities in de buffer te plaatsen)
  • sample() (om een batch transities uit de buffer te halen)
  • increase_beta() (om importance sampling te vergroten)
  • update_priorities() (om de gesamplede prioriteiten bij te werken)

De functie describe_episode() wordt opnieuw gebruikt om elke episode te beschrijven.

Deze oefening maakt deel uit van de cursus

Deep Reinforcement Learning in Python

Cursus bekijken

Oefeninstructies

  • Maak een Prioritized Experience Replay-buffer met een capaciteit van 10000 transities.
  • Vergroot de invloed van importance sampling in de tijd door de parameter beta te updaten.
  • Werk de prioriteit van de gesamplede ervaringen bij op basis van hun nieuwste TD-fout.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

# Instantiate a Prioritized Replay Buffer with capacity 10000
replay_buffer = ____(____)

for episode in range(5):
    state, info = env.reset()
    done = False   
    step = 0
    episode_reward = 0    
    # Increase the replay buffer's beta parameter
    replay_buffer.____
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(64)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_q_values = target_network(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)            
            td_errors = target_q_values - q_values
            # Update the replay buffer priorities for that batch
            replay_buffer.____(____, ____)
            loss = torch.sum(weights * (q_values - target_q_values) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
Code bewerken en uitvoeren