Local SGD with Accelerator
You've implemented gradient accumulation and gradient checkpointing to streamline memory usage for your language translation model. Training is still a bit slow, so you decide to add local SGD to your training loop to improve communication efficiency between devices. Build the training loop with local SGD!
The model, train_dataloader, and accelerator have been pre-defined, and LocalSGD has been imported.
This exercise is part of the course
Efficient AI Model Training with PyTorch
Exercise instructions
- Set up a context manager for local SGD, and synchronize gradients every eight steps.
- Step the local SGD context manager.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
# Set up a context manager to synchronize gradients every eight steps
with ____(accelerator=accelerator, model=model, ____=____, enabled=True) as local_sgd:
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            inputs, targets = batch["input_ids"], batch["labels"]
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            # Step the local SGD context manager
            ____.____()
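
To build intuition for why synchronizing only every eight steps reduces communication, here is a minimal toy sketch in plain Python. It does not use Accelerate and is not the exercise solution. The function local_sgd_sim, its parameters, and the two-worker quadratic objective are all hypothetical, chosen only for illustration. It assumes the common description of local SGD, in which each worker updates its own copy of the parameters and the copies are averaged at every synchronization point.

```python
# Toy simulation of local SGD (hypothetical helper, not part of Accelerate).
# Each "worker" minimizes (w - target)^2 on its own data and only averages
# its parameter with the other workers every `sync_every` steps.

def local_sgd_sim(targets, sync_every=8, total_steps=32, lr=0.1, w0=0.0):
    """Return the final per-worker parameters and the number of sync rounds."""
    weights = [w0 for _ in targets]  # one scalar parameter per worker
    sync_rounds = 0
    for step in range(1, total_steps + 1):
        # Local update with no communication: d/dw (w - t)^2 = 2 * (w - t)
        weights = [w - lr * 2.0 * (w - t) for w, t in zip(weights, targets)]
        if step % sync_every == 0:
            # The only point where workers communicate: average the parameters
            mean_w = sum(weights) / len(weights)
            weights = [mean_w for _ in weights]
            sync_rounds += 1
    return weights, sync_rounds


weights, sync_rounds = local_sgd_sim(targets=[1.0, 3.0], sync_every=8, total_steps=32)
print(sync_rounds)  # 4 sync rounds for 32 steps, instead of 32 with per-step sync
print(weights)      # both workers hold the same averaged value after the last sync
```

In this sketch the workers communicate 4 times over 32 steps instead of once per step. That trade-off is the motivation for stepping the local SGD context manager inside the training loop: it lets the context manager count local steps and decide when a synchronization is due.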