Ponto de verificação de gradiente com o Accelerator
Você continua a otimizar o uso da memória para poder treinar o modelo de tradução de idiomas no dispositivo. O acúmulo de gradiente ajudou você a treinar com eficiência em lotes maiores. Com base nesse trabalho, você pode adicionar o checkpointing de gradiente para reduzir o espaço de memória do seu modelo.
Os endereços model, train_dataloader e accelerator foram predefinidos.
Este exercício faz parte do curso
Treinamento eficiente de modelos de IA com PyTorch
Instruções do exercício
- Ative o ponto de verificação de gradiente no site
model. - Configure um gerenciador de contexto
Acceleratorpara permitir o acúmulo de gradiente no sitemodel.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Enable gradient checkpointing on the model
____.____()
for batch in train_dataloader:
with accelerator.accumulate(model):
inputs, targets = batch["input_ids"], batch["labels"]
# Get the outputs from a forward pass of the model
____ = ____(____, labels=targets)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
print(f"Loss = {loss}")