IniziaInizia gratis

Ottimizzare l'addestramento del modello con Lightning

Implementando tecniche automatiche come ModelCheckpoint ed EarlyStopping, farai sì che il tuo modello selezioni i parametri con le prestazioni migliori evitando calcoli inutili.

Il dataset, un sottoinsieme dell'Osmanya MNIST, offre un caso d'uso reale in cui tecniche di training scalabili possono migliorare significativamente efficienza e accuratezza.

OsmanyaDataModule e ImageClassifier sono già stati predefiniti per te.

Questo esercizio fa parte del corso

Modelli di AI scalabili con PyTorch Lightning

Visualizza il corso

Istruzioni dell'esercizio

  • Importa i callback che userai per il salvataggio dei checkpoint del modello e per l'early stopping.
  • Addestra il modello con i callback ModelCheckpoint ed EarlyStopping.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

# Import relevant checkpoints
from lightning.pytorch.callbacks import ____, ____

class EvaluatedImageClassifier(ImageClassifier):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc)

model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()
# Train the model with ModelCheckpoint and EarlyStopping checkpoints
trainer = Trainer(____=[____(monitor="val_acc", save_top_k=1), ____(monitor="val_acc", patience=3)])
trainer.fit(model, datamodule=data_module)
Modifica ed esegui il codice