Ponto de controle de gradiente com o Trainer
Você deseja usar o checkpointing de gradiente para reduzir o espaço de memória do seu modelo. Você viu como escrever o loop de treinamento explícito com Accelerator, e agora gostaria de usar uma interface simplificada sem loops de treinamento com Trainer. O exercício levará algum tempo para ser executado com a chamada para trainer.train().
Configure os argumentos para que o site Trainer use o gradient checkpointing.
Este exercício faz parte do curso
Treinamento eficiente de modelos de IA com PyTorch
Instruções do exercício
- Use quatro etapas de acumulação de gradiente em
TrainingArguments. - Habilite o ponto de verificação de gradiente em
TrainingArguments. - Passe os argumentos de treinamento para
Trainer.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
training_args = TrainingArguments(output_dir="./results",
evaluation_strategy="epoch",
# Use four gradient accumulation steps
gradient_accumulation_steps=____,
# Enable gradient checkpointing
____=____)
trainer = Trainer(model=model,
# Pass in the training arguments
____=____,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
compute_metrics=compute_metrics)
trainer.train()