ComeçarComece gratuitamente

A função objetiva substituta recortada

Implemente a função calculate_loss() para PPO. Isso requer a codificação da principal inovação do site PPO - a função de perda substituta recortada. Isso ajuda a restringir a atualização da política para evitar que ela se afaste muito da política anterior em cada etapa.

A fórmula para o objetivo substituto cortado é

Seu ambiente tem o hiperparâmetro de recorte epsilon definido como 0,2.

Este exercício faz parte do curso

Aprendizado por reforço profundo em Python

Ver Curso

Instruções de exercício

  • Obtenha as razões de probabilidade entre \pi_\theta e \pi_{\theta_{old}} (versões com e sem recorte).
  • Calcule os objetivos substitutos (versões não cortadas e cortadas).
  • Calcule o objetivo substituto recortado do PPO.
  • Calcule a perda do ator.

Exercício interativo prático

Experimente este exercício preenchendo este código de exemplo.

def calculate_losses(critic_network, action_log_prob, action_log_prob_old,
                     reward, state, next_state, done):
    value = critic_network(state)
    next_value = critic_network(next_state)
    td_target = (reward + gamma * next_value * (1-done))
    td_error = td_target - value
    # Obtain the probability ratios
    ____, ____ = calculate_ratios(action_log_prob, action_log_prob_old, epsilon=.2)
    # Calculate the surrogate objectives
    surr1 = ratio * ____.____()
    surr2 = clipped_ratio * ____.____()    
    # Calculate the clipped surrogate objective
    objective = torch.min(____, ____)
    # Calculate the actor loss
    actor_loss = ____
    critic_loss = td_error ** 2
    return actor_loss, critic_loss
  
actor_loss, critic_loss = calculate_losses(critic_network, action_log_prob, action_log_prob_old,
                                           reward, state, next_state, done)
print(actor_loss, critic_loss)
Editar e executar código