
A2C with batch updates

In this course so far, you have been using variations on the same core DRL training loop. In practice, this structure can be extended in a number of ways, for example to accommodate batch updates.

You will now revisit the A2C training loop on the Lunar Lander environment, but instead of updating the networks at every step, you will wait until 10 steps have elapsed before running the gradient descent step. By averaging the losses over 10 steps, you will benefit from slightly more stable updates.
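
The sketch below illustrates this pattern in isolation, using a toy linear regression model and random targets as stand-ins for the actor and critic (these stand-ins are illustrative, not part of the exercise): per-step losses are accumulated in a list, and a single optimizer step is taken on their mean every 10 steps.

import torch

# Minimal sketch of batched loss updates, assuming a toy regression model
# in place of the actor and critic networks used in this exercise.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

batch_size = 10
losses = []  # per-step losses accumulated until the next update

for step in range(100):
    x = torch.randn(4)        # stand-in for a state observation
    target = torch.randn(1)   # stand-in for a learning target
    loss = (model(x) - target).pow(2).mean()
    losses.append(loss)

    if len(losses) >= batch_size:
        # Average the accumulated losses, then take one gradient step
        batch_loss = torch.stack(losses).mean()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        losses = []           # start a fresh batch

Averaging several per-step losses lowers the variance of each gradient estimate; the trade-off is that the parameters are updated less frequently.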

This exercise is part of the course

Deep Reinforcement Learning in Python


Instructions

  • Append the losses from each step to the loss tensors for the current batch.
  • Calculate the batch losses.
  • Reinitialize the loss tensors.
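
For reference, here is a minimal sketch of the tensor operations these steps rely on (the names here are illustrative, not part of the exercise scaffold):

import torch

losses = torch.tensor([])        # empty 1-D buffer of losses

# Append a new one-element loss to the buffer
new_loss = torch.tensor([0.5])
losses = torch.cat((losses, new_loss))

# Average the buffer to obtain the batch loss
batch_loss = losses.mean()

# Reinitialize the buffer for the next batch
losses = torch.tensor([])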

Hands-on interactive exercise

Try this exercise by completing the code example below.

import torch

# env, actor, critic, actor_optimizer, critic_optimizer, select_action and
# calculate_losses are assumed to be defined by the exercise setup.

actor_losses = torch.tensor([])
critic_losses = torch.tensor([])

for episode in range(10):
    state, info = env.reset()
    done = False
    episode_reward = 0
    step = 0
    while not done:
        step += 1
        action, action_log_prob = select_action(actor, state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        actor_loss, critic_loss = calculate_losses(
            critic, action_log_prob,
            reward, state, next_state, done)
        # Append to the loss tensors (reshape to 1-D so torch.cat can concatenate)
        actor_losses = torch.cat((actor_losses, actor_loss.reshape(1)))
        critic_losses = torch.cat((critic_losses, critic_loss.reshape(1)))
        if len(actor_losses) >= 10:
            # Calculate the batch losses by averaging over the last 10 steps
            actor_loss_batch = actor_losses.mean()
            critic_loss_batch = critic_losses.mean()
            actor_optimizer.zero_grad()
            actor_loss_batch.backward()
            actor_optimizer.step()
            critic_optimizer.zero_grad()
            critic_loss_batch.backward()
            critic_optimizer.step()
            # Reinitialize the loss tensors for the next batch
            actor_losses = torch.tensor([])
            critic_losses = torch.tensor([])
        state = next_state
    describe_episode(episode, reward, episode_reward, step)