Mixed precision training with basic PyTorch
You will use low-precision floating-point data types to speed up training of your language translation model. For example, 16-bit floating-point numbers (float16) are only half the size of their 32-bit counterparts (float32), which accelerates matrix multiplications and convolutions. Recall that mixed precision training involves casting operations to 16-bit floating point and scaling gradients so that small values do not underflow to zero.
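As a quick check of the size difference, you can compare per-element storage directly (a minimal sketch, separate from the exercise; it assumes nothing beyond a working torch installation):

import torch

# Each float16 element takes 2 bytes; each float32 element takes 4 bytes
half = torch.zeros(1, dtype=torch.float16)
single = torch.zeros(1, dtype=torch.float32)
print(half.element_size(), single.element_size())  # prints: 2 4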
Some objects have been preloaded: dataset, model, train_dataloader, and optimizer.
This exercise is part of the course Efficient AI Model Training with PyTorch.
Exercise instructions
- Before the loop, define a scaler for the gradients using a class from the torch library.
- In the loop, cast operations to the 16-bit floating point data type using a context manager from the torch library.
- In the loop, scale the loss and call .backward() to create scaled gradients.

A completed reference sketch is shown after the sample code below.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
# Define a scaler for the gradients
scaler = torch.amp.____()

for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Casts operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    ____.____(loss).____()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
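For reference, here is one plausible completed version of the loop — a sketch, not the official solution. It assumes the preloaded model, train_dataloader, and optimizer from this exercise, a recent PyTorch release in which torch.amp.GradScaler and torch.autocast are available, and that float16 is the intended 16-bit data type.

import torch

# Gradient scaler: keeps small float16 gradients from underflowing to zero
scaler = torch.amp.GradScaler()

for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Cast forward-pass operations to 16-bit floating point
    with torch.autocast(device_type="cpu", dtype=torch.float16):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Scale the loss, then backpropagate to create scaled gradients
    scaler.scale(loss).backward()
    # Unscale gradients, update parameters, then adjust the scale factor
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

Note that on CPU, bfloat16 is the more common autocast dtype; the float16 and device_type="cpu" choices here simply follow the scaffold above.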