Local SGD with Accelerator
You've implemented gradient accumulation and gradient checkpointing to streamline memory usage for your language translation model. Training is still a bit slow, so you decide to add local SGD to your training loop to improve communication efficiency between devices. Build the training loop with local SGD!
The model, train_dataloader, optimizer, lr_scheduler, and accelerator have been pre-defined, and LocalSGD has been imported.
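To build intuition for what local SGD changes, here is a minimal pure-Python sketch. It is not the Accelerate implementation: the function name `local_sgd_simulation`, the toy per-worker loss, and the learning rate are all assumptions made for illustration. Each simulated worker updates its own copy of a single parameter, and the copies are averaged only every `local_sgd_steps` steps, so communication happens far less often than once per step.

```python
def local_sgd_simulation(targets, num_steps, local_sgd_steps, lr=0.1):
    """Simulate local SGD on toy per-worker losses 0.5 * (w - target) ** 2.

    Hypothetical helper for illustration only; it does not use Accelerate.
    """
    weights = [0.0 for _ in targets]  # one parameter copy per worker
    sync_count = 0
    for step in range(1, num_steps + 1):
        # Each worker takes an independent gradient step on its own loss.
        weights = [w - lr * (w - t) for w, t in zip(weights, targets)]
        # Every local_sgd_steps steps, average the parameters across workers.
        if step % local_sgd_steps == 0:
            mean_w = sum(weights) / len(weights)
            weights = [mean_w] * len(weights)
            sync_count += 1
    return weights, sync_count


# Two workers, 16 steps, averaging every 8 steps: only two sync rounds.
weights, sync_count = local_sgd_simulation(
    [1.0, 3.0], num_steps=16, local_sgd_steps=8
)
print(weights, sync_count)
```

With `local_sgd_steps=1` the same function averages after every step, which corresponds to the more communication-heavy setup that local SGD is meant to relax.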
This exercise is part of the course
Efficient AI Model Training with PyTorch
Exercise instructions
- Set local_sgd_steps to synchronize model parameters across devices every eight steps.
- Step the local SGD context manager.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
# Set up a context manager to synchronize model parameters every eight steps
with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=____, enabled=True) as local_sgd:
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            inputs, targets = batch["input_ids"], batch["labels"]
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            # Step the local SGD context manager
            local_sgd.____()
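The role of stepping the context manager once per batch can be sketched with a toy stand-in. `ToyLocalSGD` below is a hypothetical class written for this illustration, not Accelerate's `LocalSGD`; it assumes the manager simply counts calls and triggers a synchronization whenever the count reaches a multiple of `local_sgd_steps`. It records those step numbers instead of communicating.

```python
class ToyLocalSGD:
    """Hypothetical stand-in that only records when a sync would happen."""

    def __init__(self, local_sgd_steps):
        self.local_sgd_steps = local_sgd_steps
        self.num_steps = 0
        self.sync_steps = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False  # do not suppress exceptions

    def step(self):
        # Count one training step; record a sync on every multiple.
        self.num_steps += 1
        if self.num_steps % self.local_sgd_steps == 0:
            self.sync_steps.append(self.num_steps)


with ToyLocalSGD(local_sgd_steps=8) as toy_local_sgd:
    for _batch in range(20):
        # Forward pass, backward pass, and optimizer updates would go here.
        toy_local_sgd.step()

sync_steps = toy_local_sgd.sync_steps
print(sync_steps)
```

Under this assumption, forgetting to call the step method inside the loop would leave the counter at zero, and no periodic synchronization would ever be triggered.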