ComeçarComece de graça

Treinamento de precisão mista com PyTorch básico

Você usará tipos de dados de ponto flutuante de baixa precisão para acelerar o treinamento do seu modelo de tradução de idiomas. Por exemplo, os tipos de dados de ponto flutuante de 16 bits (float16) têm apenas metade do tamanho de seus equivalentes de 32 bits (float32). Isso acelera os cálculos de multiplicações e convoluções de matrizes. Lembre-se de que isso envolve gradientes de escala e operações de conversão para ponto flutuante de 16 bits.

Alguns objetos foram pré-carregados: dataset, model, dataloader, e optimizer.

Este exercício faz parte do curso

Treinamento eficiente de modelos de IA com PyTorch

Ver curso

Instruções do exercício

  • Antes do loop, defina um escalonador para os gradientes usando uma classe da biblioteca torch.
  • No loop, converta as operações para o tipo de dados de ponto flutuante de 16 bits usando um gerenciador de contexto da biblioteca torch.
  • No loop, dimensione a perda e chame .backward() para criar gradientes dimensionados.

Exercício interativo prático

Experimente este exercício completando este código de exemplo.

# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Casts operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    ____.____(loss).____()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
Editar e executar o código