Optimizing model training with Lightning
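By adding Lightning callbacks such as ModelCheckpoint and EarlyStopping, you can automate two decisions. ModelCheckpoint saves the weights from the epoch with the best validation score. EarlyStopping halts training once that score stops improving. Together they keep the best-performing model without spending compute on epochs that no longer help.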
The data, a subset of the Osmanya MNIST dataset, offers a real-world use case in which these techniques can make training more efficient and help you retain the most accurate model.
OsmanyaDataModule and ImageClassifier have been predefined for you.
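To see how the two callbacks interact, the plain-Python sketch below replays a list of per-epoch validation accuracies. It reports two things: which epoch a checkpoint configured like `ModelCheckpoint(monitor="val_acc", mode="max", save_top_k=1)` would keep, and after which epoch `EarlyStopping(monitor="val_acc", mode="max", patience=3)` would halt. It is a simplified model, not Lightning's implementation. It ignores options such as `min_delta`, and the function `simulate_callbacks` and the accuracy values are hypothetical.

```python
from typing import List, Optional, Tuple


def simulate_callbacks(
    val_acc_per_epoch: List[float], patience: int = 3
) -> Tuple[int, Optional[int]]:
    """Hypothetical helper returning (best_epoch, stop_epoch) for a metric to maximize.

    best_epoch approximates ModelCheckpoint(mode="max", save_top_k=1): the epoch
    whose weights would be kept. stop_epoch approximates EarlyStopping(mode="max",
    patience=patience): the epoch after which training halts, or None if every
    epoch runs.
    """
    best_score = float("-inf")
    best_epoch = -1
    wait = 0  # consecutive epochs without a strict improvement
    for epoch, score in enumerate(val_acc_per_epoch):
        if score > best_score:
            # Improvement: the kept checkpoint is replaced and the counter resets
            best_score = score
            best_epoch = epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # Patience exhausted: training stops after this epoch
                return best_epoch, epoch
    return best_epoch, None


# Made-up accuracies: the 0.72 at epoch 8 is never reached
history = [0.50, 0.62, 0.70, 0.69, 0.71, 0.70, 0.70, 0.68, 0.72]
print(simulate_callbacks(history, patience=3))  # (4, 7)
```

The checkpoint from epoch 4 is what you would keep. Epochs 5 to 7 fail to beat it, so training ends after epoch 7 even though a later epoch might have scored higher. That is the trade-off `patience` controls.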
This exercise is part of the course Scalable AI Models with PyTorch Lightning.
Exercise instructions
- Import callbacks that you'll use for model checkpointing and early stopping.
- Train the model with the ModelCheckpoint and EarlyStopping callbacks.
Have a go at this exercise by completing this sample code.
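# Import the Trainer and the callbacks for checkpointing and early stopping
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ____, ____

# OsmanyaDataModule and ImageClassifier are predefined for this exercise
class EvaluatedImageClassifier(ImageClassifier):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        # Logged as "val_acc" so both callbacks can monitor it
        self.log("val_acc", acc)

model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()

# Train the model with the ModelCheckpoint and EarlyStopping callbacks.
# Both use mode="max" because higher accuracy is better; the default
# mode="min" would treat rising accuracy as a lack of improvement.
trainer = Trainer(
    ____=[
        ____(monitor="val_acc", mode="max", save_top_k=1),
        ____(monitor="val_acc", mode="max", patience=3),
    ]
)
trainer.fit(model, datamodule=data_module)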