Punto de control de gradiente con Trainer
Quieres utilizar el punto de control de gradiente para reducir la huella de memoria de tu modelo. Has visto cómo escribir el bucle de entrenamiento explícito con Accelerator, y ahora te gustaría utilizar una interfaz simplificada sin bucles de entrenamiento con Trainer. El ejercicio tardará algún tiempo en ejecutarse con la llamada a trainer.train().
Configura los argumentos para que Trainer utilice la comprobación de gradiente.
Este ejercicio forma parte del curso
Entrenamiento eficiente de modelos de IA con PyTorch
Instrucciones del ejercicio
- Utiliza cuatro pasos de acumulación de gradiente en
TrainingArguments. - Activa la comprobación de gradientes en
TrainingArguments. - Introduce los argumentos de entrenamiento en
Trainer.
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
training_args = TrainingArguments(output_dir="./results",
evaluation_strategy="epoch",
# Use four gradient accumulation steps
gradient_accumulation_steps=____,
# Enable gradient checkpointing
____=____)
trainer = Trainer(model=model,
# Pass in the training arguments
____=____,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
compute_metrics=compute_metrics)
trainer.train()