Metas Q fixas
Você está se preparando para treinar seu Lunar Lander com alvos Q fixos. Como pré-requisito, você precisa instanciar a rede on-line (que escolhe a ação) e a rede de destino (usada para o cálculo do TD-target).
Você também precisa implementar uma função update_target_network
que possa ser usada em cada etapa de treinamento. A rede de destino não é atualizada pela descida do gradiente; em vez disso, o site update_target_network
empurra seus pesos para a rede Q em uma pequena quantidade, garantindo que ela permaneça bastante estável ao longo do tempo.
Observe que, somente para este exercício, você usará uma rede muito pequena para que possamos imprimir e inspecionar facilmente seu dicionário de estado. Ele tem apenas uma camada oculta de tamanho 2; seu espaço de ação e espaço de estado também são de dimensão 2.
A função print_state_dict()
está disponível em seu ambiente para imprimir o dict de estado.
Este exercício faz parte do curso
Aprendizado por reforço profundo em Python
Instruções de exercício
- Obtenha o endereço
.state_dict()
para as redes alvo e on-line. - Atualize o ditado de estado para a rede de destino fazendo a média ponderada entre os parâmetros da rede on-line e da rede de destino, usando
tau
como peso para a rede on-line. - Carregue o ditado de estado atualizado de volta na rede de destino.
Exercício interativo prático
Experimente este exercício preenchendo este código de exemplo.
def update_target_network(target_network, online_network, tau):
# Obtain the state dicts for both networks
target_net_state_dict = ____
online_net_state_dict = ____
for key in online_net_state_dict:
# Calculate the updated state dict for the target network
target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
# Load the updated state dict into the target network
target_network.____
return None
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))