Aan de slagGa gratis aan de slag

Barebone DQN-verliesfunctie

Nu de functie select_action() klaar is, ben je nog maar één laatste stap verwijderd van het trainen van je agent: je gaat nu calculate_loss() implementeren.

De functie calculate_loss() geeft het netwerkverlies terug voor een willekeurige stap in de episode.

Ter referentie, het verlies is gegeven door:

De volgende voorbeeldgegevens zijn in de oefening geladen:

state = torch.rand(8)
next_state = torch.rand(8)
action = select_action(q_network, state)
reward = 1
gamma = .99
done = False

Deze oefening maakt deel uit van de cursus

Deep Reinforcement Learning in Python

Cursus bekijken

Oefeninstructies

  • Bepaal de Q-waarde van de huidige toestand.
  • Bepaal de Q-waarde van de volgende toestand.
  • Bereken de doel-Q-waarde, oftewel de TD-target.
  • Bereken de verliesfunctie, d.w.z. de squared Bellman Error.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

def calculate_loss(q_network, state, action, next_state, reward, done):
    q_values = q_network(state)
    print(f'Q-values: {q_values}')
    # Obtain the current state Q-value
    current_state_q_value = q_values[____]
    print(f'Current state Q-value: {current_state_q_value:.2f}')
    # Obtain the next state Q-value
    next_state_q_value = q_network(next_state).____    
    print(f'Next state Q-value: {next_state_q_value:.2f}')
    # Calculate the target Q-value
    target_q_value = ____ + gamma * ____ * (1-done)
    print(f'Target Q-value: {target_q_value:.2f}')
    # Obtain the loss
    loss = nn.MSELoss()(____, ____)
    print(f'Loss: {loss:.2f}')
    return loss

calculate_loss(q_network, state, action, next_state, reward, done)
Code bewerken en uitvoeren