LoslegenKostenlos starten

Mixed-Precision-Training mit einfachem PyTorch

Du verwendest Gleitkomma-Datentypen mit niedrigerer Präzision, um das Training deines Übersetzungsmodells zu beschleunigen. Zum Beispiel sind 16-Bit-Gleitkommadatentypen (float16) nur halb so groß wie ihre 32-Bit-Pendants (float32). Das beschleunigt Berechnungen von Matrixmultiplikationen und Faltungen. Denk daran: Dazu müssen die Gradienten skaliert und Operationen in 16-Bit-Gleitkomma gecastet werden.

Einige Objekte wurden bereits geladen: dataset, model, dataloader und optimizer.

Diese Übung ist Teil des Kurses

<Kurs>Effizientes KI-Modelltraining mit PyTorch</Kurs>
Kurs ansehen

Übungsanweisungen

  • Definiere vor der Schleife einen Scaler für die Gradienten mit torch.amp.GradScaler.
  • Caste in der Schleife die Operationen mithilfe von torch.autocast als Kontext-Manager auf den 16-Bit-Gleitkomma-Datentyp.
  • Skaliere in der Schleife den Loss und rufe .backward() auf, um skalierte Gradienten zu erzeugen.

Interaktive praktische Übung

Versuche dich an dieser Übung, indem du diesen Beispielcode vervollständigst.

# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Casts operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    scaler.____(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
Code bearbeiten und ausführen