Training van modellen optimaliseren met Lightning
Door geautomatiseerde technieken zoals ModelCheckpoint en EarlyStopping toe te passen, zorg je ervoor dat je model de best presterende parameters kiest en onnodige berekeningen vermijdt.
De gegevensset, een subset van de Osmanya MNIST-gegevensset, biedt een praktijkvoorbeeld waarbij schaalbare AI-trainingstechnieken de efficiëntie en nauwkeurigheid aanzienlijk kunnen verbeteren.
OsmanyaDataModule en ImageClassifier zijn al voor je gedefinieerd.
Deze oefening maakt deel uit van de cursus
Schaalbare AI-modellen met PyTorch Lightning
Oefeninstructies
- Importeer callbacks die je gebruikt voor model check-pointing en early stopping.
- Train het model met de callbacks
ModelCheckpointenEarlyStopping.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
# Import relevant checkpoints
from lightning.pytorch.callbacks import ____, ____
class EvaluatedImageClassifier(ImageClassifier):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log("val_acc", acc)
model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()
# Train the model with ModelCheckpoint and EarlyStopping checkpoints
trainer = Trainer(____=[____(monitor="val_acc", save_top_k=1), ____(monitor="val_acc", patience=3)])
trainer.fit(model, datamodule=data_module)