Ottimizzare l'addestramento del modello con Lightning
Implementando tecniche automatiche come ModelCheckpoint ed EarlyStopping, farai sì che il tuo modello selezioni i parametri con le prestazioni migliori evitando calcoli inutili.
Il dataset, un sottoinsieme dell'Osmanya MNIST, offre un caso d'uso reale in cui tecniche di training scalabili possono migliorare significativamente efficienza e accuratezza.
OsmanyaDataModule e ImageClassifier sono già stati predefiniti per te.
Questo esercizio fa parte del corso
Modelli di AI scalabili con PyTorch Lightning
Istruzioni dell'esercizio
- Importa i callback che userai per il salvataggio dei checkpoint del modello e per l'early stopping.
- Addestra il modello con i callback
ModelCheckpointedEarlyStopping.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
# Import relevant checkpoints
from lightning.pytorch.callbacks import ____, ____
class EvaluatedImageClassifier(ImageClassifier):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log("val_acc", acc)
model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()
# Train the model with ModelCheckpoint and EarlyStopping checkpoints
trainer = Trainer(____=[____(monitor="val_acc", save_top_k=1), ____(monitor="val_acc", patience=3)])
trainer.fit(model, datamodule=data_module)