1. Gradient accumulation
We've explored
2. Distributed training
distributed training, which can strain resources when models or datasets are large, but
3. Efficient training
efficient training aims to optimize these resources.
4. Improving training efficiency
We'll examine three drivers of training efficiency: memory,
5. Improving training efficiency
communication,
6. Improving training efficiency
and computation.
7. Gradient accumulation improves memory efficiency
We'll start with gradient accumulation, which addresses memory efficiency.
8. The problem with large batch sizes
Training with large batch sizes produces better gradient estimates, which help the model learn more quickly. However, GPU memory limits how large a batch can be.
9. How does gradient accumulation work?
Gradient accumulation addresses this issue by summing gradients from smaller batches that fit in memory, effectively simulating training on one larger batch. The training loop then updates the model parameters with the summed gradients. Throughout this video, we'll sum gradients every two batches, effectively doubling the batch size.
10. PyTorch, Accelerator, and Trainer
We'll walk through each step using PyTorch for full training customization.
11. PyTorch, Accelerator, and Trainer
Then we'll simplify the training loop with Accelerator.
12. PyTorch, Accelerator, and Trainer
Finally, we'll cover Trainer as an easy interface with no training loop.
13. Gradient accumulation with PyTorch
Let's get started with PyTorch. First we load our tokenized customer review dataset and move it to a device.
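Here's a minimal sketch of that setup; tokenized_dataset and model are illustrative names for the objects assumed from earlier steps:

    import torch
    from torch.utils.data import DataLoader

    # Use a GPU if one is available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # tokenized_dataset and model are assumed to exist already
    dataloader = DataLoader(tokenized_dataset, batch_size=8, shuffle=True)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)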
14. Gradient accumulation with PyTorch
Following our diagram, we compute a forward pass and store the resulting outputs and loss.
15. Gradient accumulation with PyTorch
We divide the loss by two, the number of batches we sum over, so that the accumulated gradients match the scale of a single larger batch.
16. Gradient accumulation with PyTorch
Then we compute a backward pass, which produces gradients and accumulates their sum across batches.
17. Gradient accumulation with PyTorch
After every two batches, which we check with a modulo (percent) operator,
18. Gradient accumulation with PyTorch
we update the model with the summed gradients and reset the gradients to zero. The process then repeats, accumulating gradients every two batches.
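Putting these steps together, here is a minimal sketch of the full loop, assuming a Hugging Face-style model whose outputs include a loss, plus the dataloader, model, and optimizer set up above:

    accumulation_steps = 2  # number of batches to sum gradients over

    model.train()
    for i, batch in enumerate(dataloader):
        # Move the batch to the device
        batch = {k: v.to(device) for k, v in batch.items()}
        # Forward pass: compute outputs and loss
        outputs = model(**batch)
        # Scale the loss by the number of batches we sum over
        loss = outputs.loss / accumulation_steps
        # Backward pass: gradients accumulate in the model parameters
        loss.backward()
        # Every two batches, update the model and reset the gradients
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()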
19. From PyTorch to Accelerator
Using PyTorch provides a deeper understanding of how gradient accumulation works under the hood and allows us to fully customize our training loop for specialized use cases like reinforcement learning.
20. From PyTorch to Accelerator
Next, we'll see how Accelerator simplifies the training loop for gradient accumulation and handles device placement for distributed training.
21. Gradient accumulation with Accelerator
First we create an Accelerator with the number of gradient accumulation steps. As before, we sum gradients every two batches. Accelerator then places each batch of input data on the available devices.
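Here's a sketch of that setup, reusing the model, optimizer, and dataloader from the PyTorch example (all names are illustrative):

    from accelerate import Accelerator

    # Sum gradients every two batches
    accelerator = Accelerator(gradient_accumulation_steps=2)

    # prepare() handles device placement for the model and data
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)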
22. Gradient accumulation with Accelerator
We compute the forward pass and save the outputs and loss.
23. Gradient accumulation with Accelerator
Unlike the previous example, we don't need to scale the loss explicitly. Instead, we wrap the code inside an .accumulate() context manager, which scales the loss and tracks the batch count automatically.
24. Gradient accumulation with Accelerator
Next we compute the backward pass, using accelerator.backward() this time.
25. Gradient accumulation with Accelerator
After two batches finish,
26. Gradient accumulation with Accelerator
we update the model with the summed gradients. Unlike before, we don't need to check manually when two batches have completed; Accelerator handles that for us.
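Continuing the sketch above, the loop itself looks like this:

    model.train()
    for batch in dataloader:
        # The context manager scales the loss and tracks the batch count
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            # Backward pass through Accelerator
            accelerator.backward(loss)
            # Accelerator skips the update until two batches have finished
            optimizer.step()
            optimizer.zero_grad()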
27. From Accelerator to Trainer
We've seen how Accelerator simplifies the structure of the training loop.
28. From Accelerator to Trainer
Alternatively, Trainer provides an easy interface that requires no training loop at all.
29. Gradient accumulation with Trainer
First we specify the number of gradient accumulation steps in TrainingArguments and then pass the arguments to Trainer. Calling the train() method begins training and prints metrics that show training progress.
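A minimal sketch, assuming the same model and a tokenized train_dataset (names are illustrative):

    from transformers import Trainer, TrainingArguments

    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,  # sum gradients every two batches
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )

    # Starts training and prints progress metrics
    trainer.train()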
30. Let's practice!
Your turn to accumulate gradients!