1. Gradient checkpointing and local SGD
Let's continue our journey to make distributed training more efficient.
2. Improving training efficiency
We've seen the key drivers of training efficiency.
3. Gradient checkpointing improves memory efficiency
Now we'll see how gradient checkpointing improves memory efficiency,
4. Local SGD addresses communication efficiency
and how local stochastic gradient descent, or SGD, addresses communication efficiency.
5. What is gradient checkpointing?
Gradient checkpointing reduces memory usage by selectively saving activations during the forward pass. For example, our model needs to calculate A + B = C.
6. What is gradient checkpointing?
We need to compute A and B first before computing C.
7. What is gradient checkpointing?
Once we compute C, we no longer need A and B for other calculations in the forward pass. The key question is whether to keep or discard A and B.
8. What is gradient checkpointing?
Without gradient checkpointing, we save A and B in memory.
9. What is gradient checkpointing?
With gradient checkpointing, we discard A and B after computing C to save memory.
10. What is gradient checkpointing?
During the backward pass, we recompute A and B so their values are available for the gradient calculation. By discarding A and B, we save memory at the expense of recomputation. Gradient checkpointing follows this process selectively, keeping activations that would be expensive to recompute during the backward pass.
11. What is gradient checkpointing?
For example, if B is expensive to recompute, then it will keep B and discard A after computing C.
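To make the trade-off concrete, here is a minimal PyTorch sketch, not taken from the course code, that wraps the computation of C in torch.utils.checkpoint; the layer names and tensor shapes are placeholders chosen for illustration.

```python
import torch
from torch.utils.checkpoint import checkpoint

f = torch.nn.Linear(512, 512)   # produces A
g = torch.nn.Linear(512, 512)   # produces B
x = torch.randn(8, 512, requires_grad=True)

def compute_c(inp):
    # A and B are intermediate activations; C = A + B
    return f(inp) + g(inp)

# Without checkpointing: autograd keeps the tensors it needs inside
# compute_c in memory until the backward pass.
c = compute_c(x)

# With checkpointing: nothing inside compute_c is stored; the segment is
# recomputed during the backward pass, trading compute for memory.
c = checkpoint(compute_c, x, use_reentrant=False)
c.sum().backward()
```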
12. Trainer and Accelerator
We'll examine Trainer and Accelerator as two options for gradient checkpointing.
13. Trainer and Accelerator
As we saw, Trainer fits applications that don't require custom training loops, like most Transformer models from the Hugging Face library.
14. Gradient checkpointing with Trainer
For Trainer, the process is straightforward. Let's build on our TrainingArguments from gradient accumulation.
15. Gradient checkpointing with Trainer
We add an argument for gradient_checkpointing, setting it to True. Then we define Trainer, passing it training_args. Calling trainer.train() will display metrics for each epoch, like accuracy and loss.
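A sketch of this setup might look like the following, assuming the model, datasets, and compute_metrics function were defined earlier; the output directory, epoch count, and accumulation steps are placeholder values.

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",            # placeholder output path
    num_train_epochs=3,                # placeholder value
    gradient_accumulation_steps=4,     # carried over from gradient accumulation
    gradient_checkpointing=True,       # discard and recompute activations
    evaluation_strategy="epoch",       # report metrics after each epoch
)

trainer = Trainer(
    model=model,                       # assumed: Hugging Face model defined earlier
    args=training_args,
    train_dataset=train_dataset,       # assumed: datasets defined earlier
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,   # assumed: returns accuracy, etc.
)

trainer.train()                        # displays loss and accuracy per epoch
```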
16. From Trainer to Accelerator
Now let's examine training loops in Accelerator. Accelerator helps with custom Transformer architectures or advanced training techniques.
17. Gradient checkpointing with Accelerator
We'll modify our training loop from gradient accumulation with one line.
18. Gradient checkpointing with Accelerator
We enable gradient checkpointing by calling the model's gradient_checkpointing_enable() method. Now we are using both gradient accumulation and gradient checkpointing in our training loop.
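Here is a minimal sketch of that loop, assuming a Hugging Face Transformer model (which exposes gradient_checkpointing_enable()) plus the optimizer and dataloader from the gradient accumulation lesson; variable names and the accumulation step count are placeholders.

```python
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)  # placeholder value
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# The one-line addition: enable gradient checkpointing on the model.
model.gradient_checkpointing_enable()

for batch in dataloader:
    with accelerator.accumulate(model):   # gradient accumulation
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```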
19. Local SGD improves communication efficiency
Next, let's examine how local stochastic gradient descent, or SGD, optimizes communication efficiency.
20. What is local SGD?
During training, each device computes gradients in parallel.
21. What is local SGD?
A driver node collects the gradients from every device and updates the model parameters on each device; this process is called gradient synchronization. By default, gradients are synchronized at every step, which increases training time. Local SGD reduces the frequency of gradient synchronization to speed up training.
22. Local SGD with Accelerator
Let's build on our previous loop from gradient checkpointing to enable local SGD.
23. Local SGD with Accelerator
We wrap the loop with a LocalSGD context manager, where we specify how often to synchronize gradients: in this case, every eight steps, instead of synchronizing automatically at every step.
24. Local SGD with Accelerator
In the loop, we add a call to local_sgd.step() for LocalSGD to track the batch number.
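Putting both pieces together, a sketch of the loop might look like this, assuming the same accelerator, model, optimizer, and dataloader as before; the eight-step synchronization interval matches the example above.

```python
from accelerate.local_sgd import LocalSGD

# Synchronize gradients every 8 steps instead of at every step.
with LocalSGD(accelerator=accelerator, model=model,
              local_sgd_steps=8, enabled=True) as local_sgd:
    for batch in dataloader:
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            local_sgd.step()   # lets LocalSGD track the step count
```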
25. Let's practice!
Time for a checkpoint of your understanding!