
Mixed precision training with basic PyTorch

You will use low-precision floating-point data types to speed up training of your language translation model. For example, 16-bit floating-point values (float16) are only half the size of their 32-bit counterparts (float32), which accelerates matrix multiplications and convolutions. Recall that mixed precision training involves casting operations to 16-bit floating point and scaling gradients so that small values do not underflow.
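
As a quick, purely illustrative check of the size difference (not part of the exercise environment), you can compare per-element storage directly with a standard PyTorch install:

import torch

# float32 tensors use 4 bytes per element, float16 tensors use 2
print(torch.tensor(1.0, dtype=torch.float32).element_size())  # 4
print(torch.tensor(1.0, dtype=torch.float16).element_size())  # 2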

Some objects have been preloaded: dataset, model, train_dataloader, and optimizer.

This exercise is part of the course Efficient AI Model Training with PyTorch.


Exercise instructions

  • Before the loop, define a scaler for the gradients using a class from the torch library.
  • In the loop, cast operations to the 16-bit floating point data type using a context manager from the torch library.
  • In the loop, scale the loss and call .backward() to create scaled gradients.

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Cast operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    ____.____(loss).____()
    # Step the optimizer with unscaled gradients, update the scale factor, and reset gradients
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
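
For comparison, here is a minimal self-contained sketch of the same pattern on a CUDA device. The tiny linear model, synthetic batches, and hyperparameters are illustrative stand-ins for the preloaded model and train_dataloader, and torch.amp.GradScaler assumes a recent PyTorch release (2.3 or later); this is one way the loop might look, not the exercise's solution code.

import torch
import torch.nn as nn

device = "cuda"  # gradient scaling only has an effect on CUDA devices
model = nn.Linear(64, 1).to(device)                  # stand-in for the preloaded model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Define a scaler that scales the loss to avoid float16 gradient underflow
scaler = torch.amp.GradScaler()
for _ in range(10):                                  # stand-in for iterating over train_dataloader
    inputs = torch.randn(32, 64, device=device)      # synthetic batch
    targets = torch.randn(32, 1, device=device)
    # Cast eligible operations in the forward pass to float16
    with torch.autocast(device_type=device, dtype=torch.float16):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    # Scale the loss and backpropagate to create scaled gradients
    scaler.scale(loss).backward()
    # Step with unscaled gradients, update the scale factor, and reset gradients
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()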