Local SGD avec Accelerator
Vous avez mis en place l’accumulation de gradients et le gradient checkpointing pour optimiser l’utilisation de la mémoire de votre modèle de traduction. L’entraînement reste un peu lent, vous décidez donc d’ajouter le local SGD à votre boucle d’entraînement pour améliorer l’efficacité des communications entre appareils. Construisez la boucle d’entraînement avec le local SGD !
Les objets model, train_dataloader et accelerator ont été pré-définis, et LocalSGD a été importé.
Cet exercice fait partie du cours
Entraîner efficacement des modèles d’IA avec PyTorch
Instructions
- Configurez un gestionnaire de contexte pour le local SGD et synchronisez les gradients toutes les huit étapes.
- Faites avancer le gestionnaire de contexte local SGD d’un pas.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Set up a context manager to synchronize gradients every eight steps
with ____(accelerator=accelerator, model=model, ____=____, enabled=True) as local_sgd:
for batch in train_dataloader:
with accelerator.accumulate(model):
inputs, targets = batch["input_ids"], batch["labels"]
outputs = model(inputs, labels=targets)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Step the local SGD context manager
____.____()