Barebone DQN Verlustfunktion
Da die Funktion select_action()
nun fertig ist, fehlt dir nur noch ein letzter Schritt, um deinen Agenten zu schulen: Du implementierst nun calculate_loss()
.
calculate_loss()
gibt den Netzwerkverlust für einen bestimmten Schritt der Episode zurück.
Zum Vergleich: Der Verlust wird wie folgt angegeben:
Die folgenden Beispieldaten wurden für die Übung geladen:
state = torch.rand(8)
next_state = torch.rand(8)
action = select_action(q_network, state)
reward = 1
gamma = .99
done = False
Diese Übung ist Teil des Kurses
Deep Reinforcement Learning in Python
Anleitung zur Übung
- Erfahre den aktuellen Q-Wert.
- Erhalte den nächsten Q-Wert für den Zustand.
- Berechne den Q-Zielwert, oder TD-target.
- Berechne die Verlustfunktion, d.h. den quadrierten Bellman-Fehler.
Interaktive Übung zum Anfassen
Probieren Sie diese Übung aus, indem Sie diesen Beispielcode ausführen.
def calculate_loss(q_network, state, action, next_state, reward, done):
q_values = q_network(state)
print(f'Q-values: {q_values}')
# Obtain the current state Q-value
current_state_q_value = q_values[____]
print(f'Current state Q-value: {current_state_q_value:.2f}')
# Obtain the next state Q-value
next_state_q_value = q_network(next_state).____
print(f'Next state Q-value: {next_state_q_value:.2f}')
# Calculate the target Q-value
target_q_value = ____ + gamma * ____ * (1-done)
print(f'Target Q-value: {target_q_value:.2f}')
# Obtain the loss
loss = nn.MSELoss()(____, ____)
print(f'Loss: {loss:.2f}')
return loss
calculate_loss(q_network, state, action, next_state, reward, done)