
Local SGD with Accelerator

You've implemented gradient accumulation and gradient checkpointing to streamline memory usage for your language translation model. Training is still a bit slow, so you decide to add local SGD to your training loop to improve communication efficiency between devices. Build the training loop with local SGD!

The model, train_dataloader, optimizer, lr_scheduler, and accelerator have been pre-defined, and LocalSGD has been imported. A sketch of what that preparation might look like is shown below.
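The exercise does not show this setup, so the following is only a minimal sketch of how the pre-defined objects could be prepared with Accelerate. The checkpoint name, learning rate, batch size, number of gradient accumulation steps, and the train_dataset variable are all assumptions for illustration, not part of the exercise.

from accelerate import Accelerator
from accelerate.local_sgd import LocalSGD
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup

# Any seq2seq translation checkpoint works; this one is an assumption
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
model.gradient_checkpointing_enable()  # gradient checkpointing from the earlier exercise

optimizer = AdamW(model.parameters(), lr=5e-5)

# train_dataset is assumed to yield dicts with "input_ids" and "labels"
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)
)

# gradient_accumulation_steps enables accelerator.accumulate() in the training loop
accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)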

This exercise is part of the course Efficient AI Model Training with PyTorch.


Exercise instructions

  • Set up a context manager for local SGD, and synchronize gradients every eight steps.
  • Step the local SGD context manager.

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

# Set up a context manager to synchronize gradients every eight steps
with ____(accelerator=accelerator, model=model, ____=____, enabled=True) as local_sgd:
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            inputs, targets = batch["input_ids"], batch["labels"]
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            # Step the local SGD context manager
            ____.____()
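For reference, one way to complete the blanks is with Accelerate's LocalSGD context manager and its step() method, as in the sketch below. It assumes the optimizer and lr_scheduler have been pre-defined and prepared, as in the setup sketch above.

# Synchronize gradients across devices every eight local steps
with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=8, enabled=True) as local_sgd:
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            inputs, targets = batch["input_ids"], batch["labels"]
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            # Advance the local SGD counter; gradients sync once every local_sgd_steps
            local_sgd.step()

Calling local_sgd.step() after each optimizer update lets LocalSGD count local steps and average model parameters across devices only every eighth step, which reduces inter-device communication compared with synchronizing on every update.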