CommencerCommencer gratuitement

Entraînement à la précision mixte avec PyTorch de base

Vous utiliserez des types de données en virgule flottante de faible précision pour accélérer l'apprentissage de votre modèle de traduction linguistique. Par exemple, les types de données à virgule flottante de 16 bits (float16) ne font que la moitié de la taille de leurs homologues de 32 bits (float32). Cela permet d'accélérer les calculs de multiplications de matrices et de convolutions. Rappelons qu'il s'agit de mettre à l'échelle les gradients et les opérations de coulée sur 16 bits en virgule flottante.

Certains objets ont été préchargés : dataset, model, dataloader, et optimizer.

Cet exercice fait partie du cours

Entraînement efficace de modèles d'IA avec PyTorch

Afficher le cours

Instructions

  • Avant la boucle, définissez une échelle pour les gradients à l'aide d'une classe de la bibliothèque torch.
  • Dans la boucle, faites passer les opérations au type de données à virgule flottante 16 bits en utilisant un gestionnaire de contexte de la bibliothèque torch.
  • Dans la boucle, mettez à l'échelle la perte et appelez .backward() pour créer des gradients mis à l'échelle.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Casts operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    ____.____(loss).____()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
Modifier et exécuter le code