ComeçarComece gratuitamente

Metas Q fixas

Você está se preparando para treinar seu Lunar Lander com alvos Q fixos. Como pré-requisito, você precisa instanciar a rede on-line (que escolhe a ação) e a rede de destino (usada para o cálculo do TD-target).

Você também precisa implementar uma função update_target_network que possa ser usada em cada etapa de treinamento. A rede de destino não é atualizada pela descida do gradiente; em vez disso, o site update_target_network empurra seus pesos para a rede Q em uma pequena quantidade, garantindo que ela permaneça bastante estável ao longo do tempo.

Observe que, somente para este exercício, você usará uma rede muito pequena para que possamos imprimir e inspecionar facilmente seu dicionário de estado. Ele tem apenas uma camada oculta de tamanho 2; seu espaço de ação e espaço de estado também são de dimensão 2.

A função print_state_dict() está disponível em seu ambiente para imprimir o dict de estado.

Este exercício faz parte do curso

Aprendizado por reforço profundo em Python

Ver Curso

Instruções de exercício

  • Obtenha o endereço .state_dict() para as redes alvo e on-line.
  • Atualize o ditado de estado para a rede de destino fazendo a média ponderada entre os parâmetros da rede on-line e da rede de destino, usando tau como peso para a rede on-line.
  • Carregue o ditado de estado atualizado de volta na rede de destino.

Exercício interativo prático

Experimente este exercício preenchendo este código de exemplo.

def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = ____
    online_net_state_dict = ____
    for key in online_net_state_dict:
        # Calculate the updated state dict for the target network
        target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
        # Load the updated state dict into the target network
        target_network.____
    return None
  
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))
Editar e executar código