Gradient checkpointing con Trainer
Quieres usar gradient checkpointing para reducir el consumo de memoria de tu modelo. Ya has visto cómo escribir el bucle de entrenamiento explícito con Accelerator, y ahora te gustaría usar una interfaz simplificada sin bucles de entrenamiento con Trainer. El ejercicio tardará un poco en ejecutarse con la llamada a trainer.train().
Configura los argumentos de Trainer para usar gradient checkpointing.
Este ejercicio forma parte del curso
Entrenamiento eficiente de modelos de IA con PyTorch
Instrucciones del ejercicio
- Usa cuatro pasos de acumulación de gradientes en
TrainingArguments. - Activa el gradient checkpointing en
TrainingArguments. - Pasa los argumentos de entrenamiento a
Trainer.
ejercicio interactivo práctico
Prueba este ejercicio completando este código de ejemplo.
training_args = TrainingArguments(output_dir="./results",
evaluation_strategy="epoch",
# Use four gradient accumulation steps
gradient_accumulation_steps=____,
# Enable gradient checkpointing
gradient_checkpointing=____)
trainer = Trainer(model=model,
# Pass in the training arguments
args=____,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
compute_metrics=compute_metrics)
trainer.train()