Mixed-Precision-Training mit einfachem PyTorch
Du wirst Gleitkomma-Datentypen mit geringer Präzision verwenden, um das Training deines Übersetzungsmodells zu beschleunigen. Zum Beispiel sind 16-Bit-Gleitkomma-Datentypen (float16) nur halb so groß wie ihre 32-Bit-Pendants (float32). Das beschleunigt Matrixmultiplikationen und Faltungsoperationen. Denk daran: Dabei werden die Gradienten skaliert und Operationen auf 16-Bit-Gleitkomma gecastet.
Einige Objekte wurden bereits geladen: dataset, model, dataloader und optimizer.
Diese Übung ist Teil des Kurses
Effizientes KI-Modelltraining mit PyTorch
Anleitung zur Übung
- Definiere vor der Schleife einen Scaler für die Gradienten mithilfe einer Klasse aus der
torch-Bibliothek. - Caste in der Schleife die Operationen mithilfe eines Kontextmanagers aus der
torch-Bibliothek auf den 16-Bit-Gleitkomma-Datentyp. - Skaliere in der Schleife den Loss und rufe
.backward()auf, um skalierte Gradienten zu erzeugen.
Interaktive Übung
Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.
# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
inputs, targets = batch["input_ids"], batch["labels"]
# Casts operations to mixed precision
with torch.____(device_type="cpu", dtype=torch.____):
outputs = model(inputs, labels=targets)
loss = outputs.loss
# Compute scaled gradients
____.____(loss).____()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()