
Gradient accumulation

1. Gradient accumulation

We've explored

2. Distributed training

distributed training, which can strain resources with large models or datasets, but

3. Efficient training

efficient training aims to optimize these resources.

4. Improving training efficiency

We'll examine three drivers of training efficiency: memory,

5. Improving training efficiency

communication,

6. Improving training efficiency

and computation.

7. Gradient accumulation improves memory efficiency

We'll start with gradient accumulation, which addresses memory efficiency.

8. The problem with large batch sizes

Training with large batch sizes produces better gradient estimates that help the model learn more quickly. However, GPU memory constrains batch sizes.

9. How does gradient accumulation work?

Gradient accumulation addresses this issue by summing gradients from smaller, memory-compatible batches, effectively simulating training on a larger batch. Then, the training loop updates model parameters with the summed gradients. Throughout this video, we effectively double the batch size by accumulating gradients over two batches before each update.
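For instance, if GPU memory only fits a batch of 4 examples, accumulating gradients over 2 batches before each update gives an effective batch size of 4 × 2 = 8 (the numbers here are purely illustrative).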

10. PyTorch, Accelerator, and Trainer

We'll walk through each step using PyTorch for full training customization.

11. PyTorch, Accelerator, and Trainer

Then we'll simplify the training loop with Accelerator.

12. PyTorch, Accelerator, and Trainer

Finally, we'll cover Trainer as an easy interface with no training loop.

13. Gradient accumulation with PyTorch

Let's get started with PyTorch. First we load our tokenized customer review dataset and move it to a device.

14. Gradient accumulation with PyTorch

Following our diagram, we compute a forward pass and store the resulting outputs and loss.

15. Gradient accumulation with PyTorch

We divide the loss by two, the number of batches we sum over, so that the accumulated gradients match a single update on the larger effective batch.

16. Gradient accumulation with PyTorch

Then we compute a backward pass, which computes gradients and adds them to the sum accumulated across batches.

17. Gradient accumulation with PyTorch

After every two batches, which we check with a modulo (percent) operator,

18. Gradient accumulation with PyTorch

we update the model with the summed gradients and reset them to zero. The process then repeats, accumulating gradients every two batches.
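Here is a minimal sketch of this loop in PyTorch. It assumes `model`, `optimizer`, `dataloader`, and `device` are already defined elsewhere in the course, and the dictionary-style batch is illustrative.

```python
# Minimal sketch: gradient accumulation in a plain PyTorch training loop.
# Assumes `model`, `optimizer`, `dataloader`, and `device` already exist.
accumulation_steps = 2  # sum gradients over two batches

model.train()
optimizer.zero_grad()

for batch_idx, batch in enumerate(dataloader):
    # Move the tokenized batch to the device
    batch = {key: tensor.to(device) for key, tensor in batch.items()}

    # Forward pass: compute outputs and loss
    outputs = model(**batch)

    # Scale the loss by the number of accumulation steps so the
    # summed gradients match a single large-batch update
    loss = outputs.loss / accumulation_steps

    # Backward pass: gradients are added into the existing .grad buffers
    loss.backward()

    # Every two batches, update the parameters and reset the gradients
    if (batch_idx + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```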

19. From PyTorch to Accelerator

Using PyTorch provides a deeper understanding of how gradient accumulation works under the hood and allows us to fully customize our training loop for specialized use cases like reinforcement learning.

20. From PyTorch to Accelerator

Next, we'll see how Accelerator simplifies the training loop for gradient accumulation and handles device placement for distributed training.

21. Gradient accumulation with Accelerator

First we define Accelerator with the number of gradient accumulation steps. As before, we are summing gradients every two batches. Then Accelerator loads input data from the first batch onto available devices.

22. Gradient accumulation with Accelerator

We compute the forward pass and save the outputs and loss.

23. Gradient accumulation with Accelerator

Unlike the previous example, we don't need to explicitly scale the loss. Instead, we wrap the code inside an accelerator.accumulate() context manager, which automatically scales the loss and tracks the batch count.

24. Gradient accumulation with Accelerator

Next we compute the backward pass, using accelerator.backward() this time.

25. Gradient accumulation with Accelerator

After two batches finish,

26. Gradient accumulation with Accelerator

we update the model with the summed gradients. Unlike before, we don't need to manually check when two batches complete.
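A minimal sketch of the same loop with Accelerator is shown below, again assuming `model`, `optimizer`, and `dataloader` are already defined as in the PyTorch example.

```python
from accelerate import Accelerator

# Minimal sketch: gradient accumulation with Hugging Face Accelerate.
# Assumes `model`, `optimizer`, and `dataloader` already exist.
accelerator = Accelerator(gradient_accumulation_steps=2)

# prepare() handles device placement for the model, optimizer, and data
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

model.train()
for batch in dataloader:
    # accumulate() scales the loss and tracks how many batches have
    # been processed since the last parameter update
    with accelerator.accumulate(model):
        outputs = model(**batch)
        loss = outputs.loss

        # Backward pass through Accelerator
        accelerator.backward(loss)

        # With a prepared optimizer, the update only takes effect
        # on accumulation boundaries
        optimizer.step()
        optimizer.zero_grad()
```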

27. From Accelerator to Trainer

We saw how Accelerator follows a simplified training loop structure.

28. From Accelerator to Trainer

Alternatively, Trainer provides an easy interface without writing a training loop.

29. Gradient accumulation with Trainer

First we specify the number of gradient accumulation steps in TrainingArguments and then pass the arguments to Trainer. Calling the train() method begins training and prints metrics, showing training progress.
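As a rough sketch, assuming `model` and a tokenized `train_dataset` are already defined (the output directory and batch size here are illustrative):

```python
from transformers import Trainer, TrainingArguments

# Minimal sketch: gradient accumulation with the Trainer API.
# Assumes `model` and a tokenized `train_dataset` already exist;
# the output directory and batch size are illustrative.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # sum gradients over two batches
    num_train_epochs=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# train() runs the training loop and reports progress metrics
trainer.train()
```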

30. Let's practice!

Your turn to accumulate gradients!