A função objetiva substituta recortada
Implemente a função calculate_loss()
para PPO. Isso requer a codificação da principal inovação do site PPO - a função de perda substituta recortada. Isso ajuda a restringir a atualização da política para evitar que ela se afaste muito da política anterior em cada etapa.
A fórmula para o objetivo substituto cortado é
Seu ambiente tem o hiperparâmetro de recorte epsilon
definido como 0,2.
Este exercício faz parte do curso
Aprendizado por reforço profundo em Python
Instruções de exercício
- Obtenha as razões de probabilidade entre
\pi_\theta
e\pi_{\theta_{old}}
(versões com e sem recorte). - Calcule os objetivos substitutos (versões não cortadas e cortadas).
- Calcule o objetivo substituto recortado do PPO.
- Calcule a perda do ator.
Exercício interativo prático
Experimente este exercício preenchendo este código de exemplo.
def calculate_losses(critic_network, action_log_prob, action_log_prob_old,
reward, state, next_state, done):
value = critic_network(state)
next_value = critic_network(next_state)
td_target = (reward + gamma * next_value * (1-done))
td_error = td_target - value
# Obtain the probability ratios
____, ____ = calculate_ratios(action_log_prob, action_log_prob_old, epsilon=.2)
# Calculate the surrogate objectives
surr1 = ratio * ____.____()
surr2 = clipped_ratio * ____.____()
# Calculate the clipped surrogate objective
objective = torch.min(____, ____)
# Calculate the actor loss
actor_loss = ____
critic_loss = td_error ** 2
return actor_loss, critic_loss
actor_loss, critic_loss = calculate_losses(critic_network, action_log_prob, action_log_prob_old,
reward, state, next_state, done)
print(actor_loss, critic_loss)