Gradient checkpointing con Trainer
Quieres usar gradient checkpointing para reducir el consumo de memoria de tu modelo. Ya has visto cómo escribir el bucle de entrenamiento explícito con Accelerator, y ahora te gustaría usar una interfaz simplificada sin bucles de entrenamiento con Trainer. El ejercicio tardará un poco en ejecutarse con la llamada a trainer.train().
Configura los argumentos de Trainer para usar gradient checkpointing.
Este ejercicio forma parte del curso
Entrenamiento eficiente de modelos de IA con PyTorch
Instrucciones del ejercicio
- Usa cuatro pasos de acumulación de gradientes en
TrainingArguments. - Activa el gradient checkpointing en
TrainingArguments. - Pasa los argumentos de entrenamiento a
Trainer.
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
training_args = TrainingArguments(output_dir="./results",
evaluation_strategy="epoch",
# Use four gradient accumulation steps
gradient_accumulation_steps=____,
# Enable gradient checkpointing
____=____)
trainer = Trainer(model=model,
# Pass in the training arguments
____=____,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
compute_metrics=compute_metrics)
trainer.train()