Mixed precision training with basic PyTorch
You will use low-precision floating point data types to speed up training for your language translation model. For example, 16-bit floating point data types (float16) are only half the size of their 32-bit counterparts (float32), which accelerates computations such as matrix multiplications and convolutions. Recall that mixed precision training involves scaling gradients and casting operations to 16-bit floating point.
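As a quick illustration of the size difference (a minimal sketch, not part of the exercise; it assumes only that torch is installed), you can compare the per-element memory footprint of the two data types:

import torch

# A float16 element occupies 2 bytes, a float32 element 4 bytes
print(torch.tensor(1.0, dtype=torch.float16).element_size())  # 2
print(torch.tensor(1.0, dtype=torch.float32).element_size())  # 4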
Some objects have been preloaded: dataset, model, train_dataloader, and optimizer.
This exercise is part of the course Efficient AI Model Training with PyTorch.
Exercise instructions
- Before the loop, define a scaler for the gradients using a class from the torch library.
- In the loop, cast operations to the 16-bit floating point data type using a context manager from the torch library.
- In the loop, scale the loss and call .backward() to create scaled gradients.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
# Define a scaler for the gradients
scaler = torch.amp.____()

for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Casts operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    ____.____(loss).____()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
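For reference, one possible completion of the blanks is sketched below. It assumes the preloaded model is a Hugging Face-style model whose outputs carry a .loss attribute, and it keeps the exercise's CPU autocast with float16 (in practice, bfloat16 is the more commonly used 16-bit type on CPU, while float16 together with GradScaler is typical on CUDA devices).

import torch

# Define a scaler for the gradients (gradient scaling mainly matters on CUDA;
# without a CUDA device, GradScaler disables itself and becomes a no-op)
scaler = torch.amp.GradScaler()

for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Cast forward-pass operations to 16-bit floating point
    with torch.autocast(device_type="cpu", dtype=torch.float16):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Scale the loss and compute scaled gradients
    scaler.scale(loss).backward()
    # Unscale the gradients and step the optimizer, then update the scale factor
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

When scaling is enabled, scaler.step(optimizer) skips the optimizer step if the gradients contain infs or NaNs, and scaler.update() then adjusts the scale factor for the next iteration; this is what keeps small float16 gradients from underflowing to zero.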