Entraînement à la précision mixte avec PyTorch de base
Vous utiliserez des types de données en virgule flottante de faible précision pour accélérer l'apprentissage de votre modèle de traduction linguistique. Par exemple, les types de données à virgule flottante de 16 bits (float16) ne font que la moitié de la taille de leurs homologues de 32 bits (float32). Cela permet d'accélérer les calculs de multiplications de matrices et de convolutions. Rappelons qu'il s'agit de mettre à l'échelle les gradients et les opérations de coulée sur 16 bits en virgule flottante.
Certains objets ont été préchargés : dataset, model, dataloader, et optimizer.
Cet exercice fait partie du cours
Entraînement efficace de modèles d'IA avec PyTorch
Instructions
- Avant la boucle, définissez une échelle pour les gradients à l'aide d'une classe de la bibliothèque
torch. - Dans la boucle, faites passer les opérations au type de données à virgule flottante 16 bits en utilisant un gestionnaire de contexte de la bibliothèque
torch. - Dans la boucle, mettez à l'échelle la perte et appelez
.backward()pour créer des gradients mis à l'échelle.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
inputs, targets = batch["input_ids"], batch["labels"]
# Casts operations to mixed precision
with torch.____(device_type="cpu", dtype=torch.____):
outputs = model(inputs, labels=targets)
loss = outputs.loss
# Compute scaled gradients
____.____(loss).____()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()