1. Batch updates in policy gradient
Welcome back!
2. Stepwise vs batch gradient updates
In our study of A2C and PPO, we have calculated our loss based on only one step.
3. Stepwise vs batch gradient updates
We select an action,
4. Stepwise vs batch gradient updates
step the environment,
5. Stepwise vs batch gradient updates
calculate the loss and update the policy,
6. Stepwise vs batch gradient updates
one step at a time.
7. Stepwise vs batch gradient updates
This works. But in practice, updating the policy over a batch of steps rather than at every step is typically preferable, much like experience replay enhances DQN performance. Batching stabilizes the training process and leverages the parallel processing capabilities of modern hardware.
We will now visit different ways to introduce batch updates for A2C and PPO.
8. Batching the A2C / PPO updates
In A2C, instead of calculating the loss and updating the policy at every step, a more common approach is to accumulate the losses over a sequence of steps, called a rollout, and only then update the policy.
In this image, we illustrate what the algorithm looks like with a rollout length of 2 steps.
9. Batching the A2C / PPO updates
We iterate through steps,
10. Batching the A2C / PPO updates
until the rollout is complete:
11. Batching the A2C / PPO updates
at which point we calculate the average loss over the duration of the rollout, and perform gradient descent to update the policy.
12. Batching the A2C / PPO updates
In practice, a rollout can be much longer, spanning one or several episodes. We iterate like this over rollouts.
13. The A2C training loop with batch updates
Let's revisit the A2C training loop to introduce these batch updates. We could follow the exact same approach with the PPO loss function.
We are going to accumulate the losses in tensors, which we will fill step by step; so we initialize two empty tensors for the actor and the critic with torch.tensor([]).
We then start the outer loop over episodes and the inner loop over steps.
14. The A2C training loop with batch updates
At the end of each step we calculate the loss and append it to our batch with torch.cat, but we do not perform gradient descent until the batch is full.
When that happens, we take the average loss with the mean method, perform a gradient descent step, and reinitialize the losses for the next batch. We do that for both the actor and the critic networks.
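As a rough sketch under a few assumptions, the batched loop described above could look like the following. The environment is assumed to follow the Gymnasium API, and num_episodes, the rollout length of 10 steps, the actor and critic networks, their optimizers, and the select_action and calculate_losses helpers are illustrative placeholders assumed to be defined elsewhere, not code from the video.

import torch

# Illustrative sketch only: `env`, `actor`, `critic`, their optimizers,
# `num_episodes`, and the `select_action` / `calculate_losses` helpers are
# assumed to be defined elsewhere.
rollout_length = 10  # assumed batch size, in steps

actor_losses = torch.tensor([])
critic_losses = torch.tensor([])

for episode in range(num_episodes):
    state, info = env.reset()
    done = False
    while not done:
        action, action_log_prob = select_action(actor, state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Calculate the step losses but do not update the networks yet;
        # `calculate_losses` is assumed to detach shared terms so the two
        # losses can be backpropagated independently.
        actor_loss, critic_loss = calculate_losses(
            actor, critic, action_log_prob, reward, state, next_state, done
        )
        actor_losses = torch.cat((actor_losses, actor_loss.unsqueeze(0)))
        critic_losses = torch.cat((critic_losses, critic_loss.unsqueeze(0)))

        # Once the batch is full, average the losses, take one gradient
        # step per network, and reinitialize the loss tensors.
        if len(actor_losses) == rollout_length:
            actor_optimizer.zero_grad()
            actor_losses.mean().backward()
            actor_optimizer.step()

            critic_optimizer.zero_grad()
            critic_losses.mean().backward()
            critic_optimizer.step()

            actor_losses = torch.tensor([])
            critic_losses = torch.tensor([])

        state = next_state

Since only the loss accumulation and the deferred gradient step change, swapping the PPO loss into calculate_losses would follow the exact same structure.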
15. A2C / PPO with multiple agents
Let's explore further concepts around batch processing. We'll introduce this topic at a high level, but a full implementation is out of scope for this course. Instead, we'll cover the concepts that will set you up to work with more complicated frameworks in the future, featuring multiple agents, minibatches, or multiple epochs.
It is common practice, with A2C and PPO, to have multiple agents run the same policy independently in parallel to accumulate diverse experiences. This diversity can further reduce correlation within a batch and improve learning stability.
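As one way to picture this, a vectorized environment can run several copies of the task in lockstep. The sketch below assumes the CartPole task, an actor network that accepts a batch of states, and a rollout length of 16; none of these choices are prescribed by the video.

import gymnasium as gym
import torch

# Assumptions: the CartPole task, an `actor` network mapping a batch of
# states to action probabilities, and a rollout length of 16 steps.
num_envs = 8
rollout_length = 16

envs = gym.vector.SyncVectorEnv(
    [lambda: gym.make("CartPole-v1") for _ in range(num_envs)]
)

states, infos = envs.reset(seed=42)  # states has shape (num_envs, obs_dim)

for step in range(rollout_length):
    with torch.no_grad():
        probs = actor(torch.as_tensor(states, dtype=torch.float32))
    actions = torch.distributions.Categorical(probs).sample()
    states, rewards, terminations, truncations, infos = envs.step(actions.numpy())
    # Each step yields num_envs transitions collected under the same policy,
    # so a rollout of length 16 contains 8 * 16 = 128 diverse transitions.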
16. Rollouts and minibatches
To overcome memory constraints, we can also divide large rollouts into minibatches, shuffling the rollout data first to reduce the correlation within a minibatch.
In this case, we update the policy several times per rollout. This can be problematic, as the updated policy starts diverging from the one that generated the rollouts. A2C is quite sensitive to this; but PPO is resilient, as its clipped objective function prevents large policy deviations.
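A minimal sketch of this minibatching, assuming the rollout has already been flattened into tensors (all names below, including the update_policy placeholder, are illustrative):

import torch

# Assumed: `rollout_states`, `rollout_actions`, and `rollout_returns` hold one
# flattened rollout; `update_policy` stands in for the A2C or PPO loss
# computation and gradient step.
batch_size = 64
rollout_size = rollout_states.shape[0]

# Shuffle once so each minibatch mixes transitions from across the rollout,
# reducing correlation within a minibatch.
indices = torch.randperm(rollout_size)

for start in range(0, rollout_size, batch_size):
    mb_idx = indices[start:start + batch_size]
    # The policy is updated several times per rollout, once per minibatch.
    update_policy(
        rollout_states[mb_idx],
        rollout_actions[mb_idx],
        rollout_returns[mb_idx],
    )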
17. PPO with multiple epochs
PPO can even learn over multiple epochs from the same set of rollouts. Being able to perform multiple updates on each data point makes PPO particularly sample-efficient. The setup illustrated here, with multiple agents and multiple epochs, is the default PPO implementation in most DRL frameworks.
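Building on the minibatch sketch above, reusing the same rollout for several epochs only adds an outer loop; the epoch count here is an illustrative assumption, not a value from the video.

num_epochs = 4  # assumed; frameworks typically default to a small number

for epoch in range(num_epochs):
    # Reshuffle each epoch so the minibatches differ between passes.
    indices = torch.randperm(rollout_size)
    for start in range(0, rollout_size, batch_size):
        mb_idx = indices[start:start + batch_size]
        # PPO's clipped objective keeps the updated policy close to the one
        # that collected the rollout, which is what makes this reuse safe.
        update_policy(
            rollout_states[mb_idx],
            rollout_actions[mb_idx],
            rollout_returns[mb_idx],
        )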
18. Let's practice!
Let's practice!