
Gradient checkpointing and local SGD

1. Gradient checkpointing and local SGD

Let's continue our journey to make distributed training more efficient.

2. Improving training efficiency

We've already seen the key drivers of training efficiency.

3. Gradient checkpointing improves memory efficiency

Now, we'll see how gradient checkpointing improves memory efficiency

4. Local SGD addresses communication efficiency

and how local stochastic gradient descent, or SGD, addresses communication efficiency.

5. What is gradient checkpointing?

Gradient checkpointing reduces memory usage by selectively saving activations during the forward pass. For example, suppose our model needs to calculate A + B = C.

6. What is gradient checkpointing?

We need to compute A and B first before computing C.

7. What is gradient checkpointing?

Once we compute C, we don't need A and B for other calculations in the forward pass. The key question is whether we save or remove A and B.

8. What is gradient checkpointing?

Without gradient checkpointing, we save A and B in memory.

9. What is gradient checkpointing?

With gradient checkpointing, we discard A and B after computing C to save memory.

10. What is gradient checkpointing?

During the backward pass, we recompute A and B so that their gradients can be calculated. By discarding A and B, we save memory at the expense of recomputation. Gradient checkpointing follows a similar process: it stores only the activations that would be expensive to recompute during the backward pass.

11. What is gradient checkpointing?

For example, if B is expensive to recompute, then it will keep B and discard A after computing C.
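To make the trade-off concrete, here is a minimal PyTorch sketch (illustrative only, not part of the lesson's training setup): torch.utils.checkpoint frees the wrapped stage's activations after the forward pass and recomputes them during the backward pass. The two-stage network and its sizes are made up for this example.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical two-stage model: the first stage is cheap to recompute,
# so we checkpoint it and let its activations be rebuilt during backward.
cheap_stage = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())
expensive_stage = torch.nn.Linear(512, 10)

x = torch.randn(4, 512)

# checkpoint() discards cheap_stage's intermediate activations after the
# forward pass instead of keeping them in memory.
hidden = checkpoint(cheap_stage, x, use_reentrant=False)
loss = expensive_stage(hidden).sum()

# During backward, cheap_stage runs its forward pass again to rebuild the
# activations needed for its gradients: memory saved, compute repeated.
loss.backward()
```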

12. Trainer and Accelerator

We'll examine Trainer and Accelerator as two options for gradient checkpointing.

13. Trainer and Accelerator

As we saw, Trainer fits applications that don't require custom training loops, like most Transformer models from the Hugging Face library.

14. Gradient checkpointing with Trainer

For Trainer, the process is straightforward. Let's build on our TrainingArguments from gradient accumulation.

15. Gradient checkpointing with Trainer

We add an argument for gradient_checkpointing, setting it to True. Then we define Trainer, passing it training_args. Calling trainer.train() will display metrics for each epoch, like accuracy and loss.
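Roughly, the code looks like the sketch below. The model, datasets, and compute_metrics function are assumed to be the ones defined in earlier lessons, and the output directory name is just a placeholder.

```python
from transformers import Trainer, TrainingArguments

# Building on the gradient accumulation setup, with one extra argument.
training_args = TrainingArguments(
    output_dir="./results",            # placeholder output path
    gradient_accumulation_steps=4,     # carried over from the previous lesson
    gradient_checkpointing=True,       # discard and recompute activations
    evaluation_strategy="epoch",       # report metrics each epoch
)

trainer = Trainer(
    model=model,                       # assumed defined earlier in the course
    args=training_args,
    train_dataset=train_dataset,       # assumed defined earlier
    eval_dataset=eval_dataset,         # assumed defined earlier
    compute_metrics=compute_metrics,   # assumed defined earlier
)

trainer.train()  # prints per-epoch metrics such as accuracy and loss
```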

16. From Trainer to Accelerator

Now let's examine training loops in Accelerator. Accelerator helps with custom Transformer architectures or advanced training techniques.

17. Gradient checkpointing with Accelerator

We'll modify our training loop from gradient accumulation with one line.

18. Gradient checkpointing with Accelerator

We enable gradient checkpointing by calling the model's gradient_checkpointing_enable() method. Now we are using both gradient accumulation and gradient checkpointing in our training loop.
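A sketch of that loop might look like this; the accumulation step count, model, optimizer, and dataloader are assumed from the earlier gradient accumulation lesson.

```python
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)  # from previous lesson
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# The one-line change: enable gradient checkpointing on the model.
model.gradient_checkpointing_enable()

for batch in dataloader:
    with accelerator.accumulate(model):   # gradient accumulation
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```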

19. Local SGD improves communication efficiency

Next, let's examine how local stochastic gradient descent, or SGD, optimizes communication efficiency.

20. What is local SGD?

During training, each device computes gradients in parallel.

21. What is local SGD?

A driver node collects the gradients from every device and updates the model parameters on each of them; this process is called gradient synchronization. By default, gradients are synchronized at every step, which adds communication overhead and increases training time. Local SGD reduces the frequency of gradient synchronization to speed up training.

22. Local SGD with Accelerator

Let's build on our previous loop from gradient checkpointing to enable local SGD.

23. Local SGD with Accelerator

We wrap the loop with a LocalSGD context manager, where we specify when to synchronize gradients: in this case, every eight steps, instead of automatically synchronizing every step.

24. Local SGD with Accelerator

In the loop, we add a call to local_sgd.step() for LocalSGD to track the batch number.
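Putting it together, a sketch of the loop might look like this; again, the model, optimizer, and dataloader are assumed from the earlier setup, and eight local steps matches the example above.

```python
from accelerate import Accelerator
from accelerate.local_sgd import LocalSGD

accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
model.gradient_checkpointing_enable()

# Synchronize gradients across devices every 8 steps instead of every step.
with LocalSGD(accelerator=accelerator, model=model,
              local_sgd_steps=8, enabled=True) as local_sgd:
    for batch in dataloader:
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            local_sgd.step()  # lets LocalSGD count steps and sync when due
```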

25. Let's practice!

Time for a checkpoint of your understanding!