Entrenamiento del doble DQN
Ahora modificarás tu código para DQN para implementar el doble DQN.
El doble DQN sólo requiere un ajuste mínimo del algoritmo DQN, pero contribuye en gran medida a resolver el problema de la sobreestimación del valor Q y a menudo funciona mejor que DQN.
Este ejercicio forma parte del curso
Aprendizaje profundo por refuerzo en Python
Instrucciones de ejercicio
- Calcula las acciones siguientes para el cálculo del objetivo Q utilizando la página
online_network()
, asegurándote de obtener la acción y la forma correctas. - Estima los valores Q a estas acciones con el
target_network()
, asegurándote de nuevo de obtener los valores y la forma correctos.
Ejercicio interactivo práctico
Pruebe este ejercicio completando este código de muestra.
for episode in range(10):
state, info = env.reset()
done = False
step = 0
episode_reward = 0
while not done:
step += 1
total_steps += 1
q_values = online_network(state)
action = select_action(q_values, total_steps, start=.9, end=.05, decay=1000)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay_buffer.push(state, action, reward, next_state, done)
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones = replay_buffer.sample(64)
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
# Obtain next actions for Q-target calculation
next_actions = ____.____.____
# Estimate next Q-values from these actions
next_q_values = ____.____.____
target_q_values = rewards + gamma * next_q_values * (1-dones)
loss = nn.MSELoss()(q_values, target_q_values)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_target_network(target_network, online_network, tau=.005)
state = next_state
episode_reward += reward
describe_episode(episode, reward, episode_reward, step)