Gradient checkpointing with Trainer
You want to use gradient checkpointing to reduce the memory footprint of your model. You've seen how to write an explicit training loop with Accelerator; now you'll use Trainer, a simplified interface that manages the training loop for you. The exercise will take some time to run because of the call to trainer.train().
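For comparison, in the explicit-training-loop approach you typically switch checkpointing on directly on the model object. A minimal sketch, assuming a Hugging Face model loaded with transformers (the model name here is illustrative and not part of this exercise's environment):

from transformers import AutoModelForCausalLM

# Illustrative model; any Hugging Face PreTrainedModel exposes this method
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Recompute activations during the backward pass instead of storing them,
# trading extra compute for a smaller memory footprint
model.gradient_checkpointing_enable()

With Trainer, you get the same effect declaratively by setting an option in TrainingArguments, as the exercise below shows.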
Set up the arguments for Trainer to use gradient checkpointing.
Exercise instructions
- Use four gradient accumulation steps in TrainingArguments.
- Enable gradient checkpointing in TrainingArguments.
- Pass in the training arguments to Trainer.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
training_args = TrainingArguments(output_dir="./results",
                                  evaluation_strategy="epoch",
                                  # Use four gradient accumulation steps
                                  gradient_accumulation_steps=____,
                                  # Enable gradient checkpointing
                                  ____=____)

trainer = Trainer(model=model,
                  # Pass in the training arguments
                  ____=____,
                  train_dataset=dataset["train"],
                  eval_dataset=dataset["validation"],
                  compute_metrics=compute_metrics)

trainer.train()