
Mixed precision training with basic PyTorch

You will use low-precision floating-point data types to speed up training for your language translation model. For example, 16-bit floating-point values (float16) are half the size of their 32-bit counterparts (float32), which accelerates matrix multiplications and convolutions. Recall that mixed precision training involves scaling gradients and casting operations to 16-bit floating point.
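For a quick sense of the memory difference, here is a small illustrative sketch (not part of the exercise) that casts a float32 tensor to float16 and compares the bytes used per element:

import torch

# Purely illustrative: compare per-element storage of float32 vs. float16
x32 = torch.ones(1_000, dtype=torch.float32)
x16 = x32.to(torch.float16)
print(x32.element_size())  # 4 bytes per element
print(x16.element_size())  # 2 bytes per element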

Some objects have been preloaded: dataset, model, train_dataloader, and optimizer.
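For orientation only, the preloaded objects might have been created along the lines of the sketch below; the model checkpoint, batch size, and learning rate are assumptions for illustration, not the exercise's actual setup.

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM

# Assumed setup, for illustration only; the real preloaded objects may differ
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")   # a translation-capable model
train_dataloader = DataLoader(dataset, batch_size=8)        # `dataset` is assumed to be tokenized already
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)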

This exercise is part of the course

Efficient AI Model Training with PyTorch


Exercise instructions

  • Before the loop, define a scaler for the gradients using a class from the torch library.
  • In the loop, cast operations to the 16-bit floating point data type using a context manager from the torch library.
  • In the loop, scale the loss and call .backward() to create scaled gradients.

Hands-on interactive exercise

Try this exercise by completing the sample code below.

# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Casts operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    ____.____(loss).____()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
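For reference, one way the blanks could be filled in is sketched below, assuming a recent PyTorch release where torch.amp.GradScaler and torch.autocast are available; treat it as a guide rather than the official solution.

import torch

# One possible completion of the exercise (illustrative)
scaler = torch.amp.GradScaler()          # define a scaler for the gradients
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Cast operations to 16-bit floating point
    with torch.autocast(device_type="cpu", dtype=torch.float16):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Scale the loss and backpropagate to create scaled gradients
    scaler.scale(loss).backward()
    scaler.step(optimizer)               # unscale gradients, then take an optimizer step
    scaler.update()                      # adjust the scale factor for the next iteration
    optimizer.zero_grad()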