Mixed-Precision-Training mit einfachem PyTorch
Du verwendest Gleitkomma-Datentypen mit niedrigerer Präzision, um das Training deines Übersetzungsmodells zu beschleunigen. Zum Beispiel sind 16-Bit-Gleitkommadatentypen (float16) nur halb so groß wie ihre 32-Bit-Pendants (float32). Das beschleunigt Berechnungen von Matrixmultiplikationen und Faltungen. Denk daran: Dazu müssen die Gradienten skaliert und Operationen in 16-Bit-Gleitkomma gecastet werden.
Einige Objekte wurden bereits geladen: dataset, model, dataloader und optimizer.
Diese Übung ist Teil des Kurses
<Kurs>Effizientes KI-Modelltraining mit PyTorch</Kurs>Übungsanweisungen
- Definiere vor der Schleife einen Scaler für die Gradienten mit
torch.amp.GradScaler. - Caste in der Schleife die Operationen mithilfe von
torch.autocastals Kontext-Manager auf den 16-Bit-Gleitkomma-Datentyp. - Skaliere in der Schleife den Loss und rufe
.backward()auf, um skalierte Gradienten zu erzeugen.
Interaktive praktische Übung
Versuche dich an dieser Übung, indem du diesen Beispielcode vervollständigst.
# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
inputs, targets = batch["input_ids"], batch["labels"]
# Casts operations to mixed precision
with torch.____(device_type="cpu", dtype=torch.____):
outputs = model(inputs, labels=targets)
loss = outputs.loss
# Compute scaled gradients
scaler.____(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()