Treinamento em precisão mista com PyTorch básico
Você vai usar tipos de ponto flutuante de baixa precisão para acelerar o treinamento do seu modelo de tradução de linguagem. Por exemplo, tipos de ponto flutuante de 16 bits (float16) têm apenas metade do tamanho dos de 32 bits (float32). Isso acelera os cálculos de multiplicações de matrizes e convoluções. Lembre-se de que isso envolve escalonar gradientes e converter operações para ponto flutuante de 16 bits.
Alguns objetos já foram pré-carregados: dataset, model, dataloader e optimizer.
Este exercício faz parte do curso
Treinamento Eficiente de Modelos de IA com PyTorch
Instruções do exercício
- Antes do loop, defina um scaler para os gradientes usando uma classe da biblioteca
torch. - No loop, converta as operações para o tipo de ponto flutuante de 16 bits usando um gerenciador de contexto da biblioteca
torch. - No loop, escale a loss e chame
.backward()para criar gradientes escalonados.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
inputs, targets = batch["input_ids"], batch["labels"]
# Casts operations to mixed precision
with torch.____(device_type="cpu", dtype=torch.____):
outputs = model(inputs, labels=targets)
loss = outputs.loss
# Compute scaled gradients
____.____(loss).____()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()