Training a precisione mista con PyTorch di base
Userai tipi di dato in virgola mobile a bassa precisione per velocizzare l'addestramento del tuo modello di traduzione automatica. Ad esempio, i tipi a 16 bit (float16) occupano solo metà della memoria rispetto ai corrispettivi a 32 bit (float32). Questo accelera i calcoli di moltiplicazioni di matrici e convoluzioni. Ricorda che ciò comporta lo scaling dei gradienti e il cast delle operazioni in virgola mobile a 16 bit.
Alcuni oggetti sono stati pre-caricati: dataset, model, dataloader e optimizer.
Questo esercizio fa parte del corso
Efficient AI Model Training with PyTorch
Istruzioni dell'esercizio
- Prima del loop, definisci uno scaler per i gradienti usando
torch.amp.GradScaler. - Nel loop, effettua il cast delle operazioni al tipo a 16 bit usando
torch.autocastcome context manager. - Nel loop, scala la loss e chiama
.backward()per creare gradienti scalati.
esercizio interattivo pratico
Prova questo esercizio completando questo codice di esempio.
# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
inputs, targets = batch["input_ids"], batch["labels"]
# Casts operations to mixed precision
with torch.____(device_type="cpu", dtype=torch.____):
outputs = model(inputs, labels=targets)
loss = outputs.loss
# Compute scaled gradients
scaler.____(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()