ComeçarComece gratuitamente

Implementação do algoritmo DQN completo

Finalmente chegou a hora! Todos os pré-requisitos estão completos; agora você implementará o algoritmo DQN completo e o usará para treinar um agente do Lunar Lander. Isso significa que o seu algoritmo usará não apenas a repetição de experiência, mas também o Epsilon-Greediness decaído e os Q-Targets fixos.

A função select_action() que implementa o Decayed Epsilon Greediness está disponível para você usar, assim como a função update_target_network() do último exercício. Tudo o que resta a fazer é ajustar essas funções no loop de treinamento DQN e garantir que você esteja usando corretamente a Target Network nos cálculos de perda.

Você precisa manter um novo contador de etapas, total_steps, para decair o valor de \(\varepsilon\) ao longo do tempo. Essa variável é inicializada para você com o valor 0.

Este exercício faz parte do curso

Aprendizado por reforço profundo em Python

Ver Curso

Instruções de exercício

  • Use select_action() para implementar o Decayed Epsilon Greediness e selecionar a ação do agente; você precisará usar total_steps, o total em execução nos episódios.
  • Antes de calcular o alvo TD, desative o rastreamento de gradiente.
  • Depois de obter o próximo estado, obtenha os valores Q do próximo estado.
  • Atualize a rede de destino no final de cada etapa.

Exercício interativo prático

Experimente este exercício preenchendo este código de exemplo.

for episode in range(10):
    state, info = env.reset()
    done = False
    step = 0
    episode_reward = 0
    while not done:
        step += 1
        total_steps += 1
        q_values = online_network(state)
        # Select the action with epsilon greediness
        action = ____(____, ____, start=.9, end=.05, decay=1000)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        replay_buffer.push(state, action, reward, next_state, done)        
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones = replay_buffer.sample(64)
            q_values = online_network(states).gather(1, actions).squeeze(1)
            # Ensure gradients are not tracked
            with ____:
                # Obtain the next state Q-values
                next_q_values = ____(next_states).amax(1)
                target_q_values = rewards + gamma * next_q_values * (1-dones)
            loss = nn.MSELoss()(q_values, target_q_values)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()   
            # Update the target network weights
            ____(____, ____, tau=.005)
        state = next_state
        episode_reward += reward    
    describe_episode(episode, reward, episode_reward, step)
Editar e executar código