Inizia subitoInizia gratis

Training a precisione mista con PyTorch di base

Userai tipi di dato in virgola mobile a bassa precisione per velocizzare l'addestramento del tuo modello di traduzione automatica. Ad esempio, i tipi a 16 bit (float16) occupano solo metà della memoria rispetto ai corrispettivi a 32 bit (float32). Questo accelera i calcoli di moltiplicazioni di matrici e convoluzioni. Ricorda che ciò comporta lo scaling dei gradienti e il cast delle operazioni in virgola mobile a 16 bit.

Alcuni oggetti sono stati pre-caricati: dataset, model, dataloader e optimizer.

Questo esercizio fa parte del corso

Efficient AI Model Training with PyTorch

Visualizza corso

Istruzioni dell'esercizio

  • Prima del loop, definisci uno scaler per i gradienti usando torch.amp.GradScaler.
  • Nel loop, effettua il cast delle operazioni al tipo a 16 bit usando torch.autocast come context manager.
  • Nel loop, scala la loss e chiama .backward() per creare gradienti scalati.

esercizio interattivo pratico

Prova questo esercizio completando questo codice di esempio.

# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Casts operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    scaler.____(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
Modifica ed esegui il codice