Aan de slagGa gratis aan de slag

Training van modellen optimaliseren met Lightning

Door geautomatiseerde technieken zoals ModelCheckpoint en EarlyStopping toe te passen, zorg je ervoor dat je model de best presterende parameters kiest en onnodige berekeningen vermijdt.

De gegevensset, een subset van de Osmanya MNIST-gegevensset, biedt een praktijkvoorbeeld waarbij schaalbare AI-trainings­technieken de efficiëntie en nauwkeurigheid aanzienlijk kunnen verbeteren.

OsmanyaDataModule en ImageClassifier zijn al voor je gedefinieerd.

Deze oefening maakt deel uit van de cursus

Schaalbare AI-modellen met PyTorch Lightning

Cursus bekijken

Oefeninstructies

  • Importeer callbacks die je gebruikt voor model check-pointing en early stopping.
  • Train het model met de callbacks ModelCheckpoint en EarlyStopping.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

# Import relevant checkpoints
from lightning.pytorch.callbacks import ____, ____

class EvaluatedImageClassifier(ImageClassifier):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc)

model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()
# Train the model with ModelCheckpoint and EarlyStopping checkpoints
trainer = Trainer(____=[____(monitor="val_acc", save_top_k=1), ____(monitor="val_acc", patience=3)])
trainer.fit(model, datamodule=data_module)
Code bewerken en uitvoeren