IniziaInizia gratis

A2C con aggiornamenti in batch

Finora, in questo corso hai usato varianti dello stesso ciclo di training DRL di base. Nella pratica, ci sono diversi modi per estendere questa struttura, ad esempio per supportare aggiornamenti in batch.

Ora rivedrai il ciclo di training A2C nell'ambiente Lunar Lander, ma invece di aggiornare le reti a ogni passo, aspetterai che trascorrano 10 passi prima di eseguire lo step di discesa del gradiente. Facendo la media delle loss su 10 passi, otterrai aggiornamenti un po' più stabili.

Questo esercizio fa parte del corso

Deep Reinforcement Learning in Python

Visualizza il corso

Istruzioni dell'esercizio

  • Aggiungi le loss di ogni passo ai tensori delle loss per il batch corrente.
  • Calcola le loss del batch.
  • Reinizializza i tensori delle loss.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

actor_losses = torch.tensor([])
critic_losses = torch.tensor([])
for episode in range(10):
    state, info = env.reset()
    done = False
    episode_reward = 0
    step = 0
    while not done:
        step += 1
        action, action_log_prob = select_action(actor, state)                
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        actor_loss, critic_loss = calculate_losses(
            critic, action_log_prob, 
            reward, state, next_state, done)
        # Append to the loss tensors
        actor_losses = torch.cat((____, ____))
        critic_losses = torch.cat((____, ____))
        if len(actor_losses) >= 10:
            # Calculate the batch losses
            actor_loss_batch = actor_losses.____
            critic_loss_batch = critic_losses.____
            actor_optimizer.zero_grad(); actor_loss_batch.backward(); actor_optimizer.step()
            critic_optimizer.zero_grad(); critic_loss_batch.backward(); critic_optimizer.step()
            # Reinitialize the loss tensors
            actor_losses = ____
            critic_losses = ____
        state = next_state
    describe_episode(episode, reward, episode_reward, step)
Modifica ed esegui il codice