Local SGD avec Accelerator
Vous avez mis en place l’accumulation de gradients et le gradient checkpointing pour optimiser l’usage mémoire de votre modèle de traduction automatique. L’entraînement reste un peu lent ; vous décidez donc d’ajouter le local SGD à votre boucle d’entraînement pour améliorer l’efficacité des communications entre appareils. Construisez la boucle d’entraînement avec le local SGD !
Les objets model, train_dataloader et accelerator sont déjà définis, et LocalSGD a été importé.
Cet exercice fait partie du cours
<cours>Entraîner efficacement des modèles d’IA avec PyTorch</cours>Instructions de l’exercice
- Définissez
local_sgd_stepspour synchroniser les gradients toutes les huit étapes. - Exécutez une itération du gestionnaire de contexte local SGD.
Exercice interactif pratique
Essayez cet exercice en complétant ce code d’exemple.
# Set up a context manager to synchronize gradients every eight steps
with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=____, enabled=True) as local_sgd:
for batch in train_dataloader:
with accelerator.accumulate(model):
inputs, targets = batch["input_ids"], batch["labels"]
outputs = model(inputs, labels=targets)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Step the local SGD context manager
local_sgd.____()