ComenzarEmpieza gratis

Gradient checkpointing con Trainer

Quieres usar gradient checkpointing para reducir el consumo de memoria de tu modelo. Ya has visto cómo escribir el bucle de entrenamiento explícito con Accelerator, y ahora te gustaría usar una interfaz simplificada sin bucles de entrenamiento con Trainer. El ejercicio tardará un poco en ejecutarse con la llamada a trainer.train().

Configura los argumentos de Trainer para usar gradient checkpointing.

Este ejercicio forma parte del curso

Entrenamiento eficiente de modelos de IA con PyTorch

Ver curso

Instrucciones del ejercicio

  • Usa cuatro pasos de acumulación de gradientes en TrainingArguments.
  • Activa el gradient checkpointing en TrainingArguments.
  • Pasa los argumentos de entrenamiento a Trainer.

Ejercicio interactivo práctico

Prueba este ejercicio y completa el código de muestra.

training_args = TrainingArguments(output_dir="./results",
                                  evaluation_strategy="epoch",
                                  # Use four gradient accumulation steps
                                  gradient_accumulation_steps=____,
                                  # Enable gradient checkpointing
                                  ____=____)
trainer = Trainer(model=model,
                  # Pass in the training arguments
                  ____=____,
                  train_dataset=dataset["train"],
                  eval_dataset=dataset["validation"],
                  compute_metrics=compute_metrics)
trainer.train()
Editar y ejecutar código