Comece agoraComece grátis

Treinamento com precisão mista no PyTorch básico

Você vai usar tipos de ponto flutuante de baixa precisão para acelerar o treinamento do seu modelo de tradução de linguagem. Por exemplo, tipos de ponto flutuante de 16 bits (float16) têm apenas metade do tamanho de seus equivalentes de 32 bits (float32). Isso acelera os cálculos de multiplicações de matrizes e convoluções. Lembre-se de que isso envolve escalar os gradientes e converter operações para ponto flutuante de 16 bits.

Alguns objetos já foram carregados: dataset, model, dataloader e optimizer.

Este exercicio faz parte do curso

Treinamento Eficiente de Modelos de IA com PyTorch

Ver curso

Instruções do exercicio

  • Antes do loop, defina um scaler para os gradientes usando torch.amp.GradScaler.
  • No loop, converta as operações para o tipo de ponto flutuante de 16 bits usando torch.autocast como gerenciador de contexto.
  • No loop, escale a perda e chame .backward() para criar gradientes escalonados.

exercicio interativo prático

Tente este exercicio completando este código de exemplo.

# Define a scaler for the gradients
scaler = torch.amp.____()
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    # Casts operations to mixed precision
    with torch.____(device_type="cpu", dtype=torch.____):
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
    # Compute scaled gradients
    scaler.____(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
Editar e Executar Código