SGD local avec accélérateur
Vous avez mis en œuvre l'accumulation et le contrôle des gradients afin de rationaliser l'utilisation de la mémoire pour votre modèle de traduction linguistique. La formation étant toujours un peu lente, vous décidez d'ajouter des SGD locaux à votre boucle de formation afin d'améliorer l'efficacité de la communication entre les appareils. Construisez la boucle de formation avec les SGD locaux !
Les sites model
, train_dataloader
et accelerator
ont été prédéfinis et le site LocalSGD
a été importé.
Cet exercice fait partie du cours
Entraînement efficace de modèles d'IA avec PyTorch
Instructions
- Mettez en place un gestionnaire de contexte pour le SGD local et synchronisez les gradients tous les huit pas.
- Étape du gestionnaire de contexte SGD local.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Set up a context manager to synchronize gradients every eight steps
with ____(accelerator=accelerator, model=model, ____=____, enabled=True) as local_sgd:
for batch in train_dataloader:
with accelerator.accumulate(model):
inputs, targets = batch["input_ids"], batch["labels"]
outputs = model(inputs, labels=targets)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Step the local SGD context manager
____.____()