LoslegenKostenlos loslegen

Mixed-Precision-Training mit einfachem PyTorch

Du wirst Gleitkomma-Datentypen mit geringer Präzision verwenden, um das Training deines Übersetzungsmodells zu beschleunigen. Zum Beispiel sind 16-Bit-Gleitkomma-Datentypen (float16) nur halb so groß wie ihre 32-Bit-Pendants (float32). Das beschleunigt Matrixmultiplikationen und Faltungsoperationen. Denk daran: Dabei werden die Gradienten skaliert und Operationen auf 16-Bit-Gleitkomma gecastet.

Einige Objekte wurden bereits geladen: dataset, model, dataloader und optimizer.

Diese Übung ist Teil des Kurses

Effizientes KI-Modelltraining mit PyTorch

Kurs anzeigen

Anleitung zur Übung

  • Definiere vor der Schleife einen Scaler für die Gradienten mithilfe einer Klasse aus der torch-Bibliothek.
  • Caste in der Schleife die Operationen mithilfe eines Kontextmanagers aus der torch-Bibliothek auf den 16-Bit-Gleitkomma-Datentyp.
  • Skaliere in der Schleife den Loss und rufe .backward() auf, um skalierte Gradienten zu erzeugen.

Interaktive Übung

Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.

# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Casts operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    ____.____(loss).____()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
Code bearbeiten und ausführen