Optimisation de l'entraînement des modèles avec Lightning
En mettant en œuvre des techniques automatisées telles que ModelCheckpoint
et EarlyStopping
, vous vous assurerez que votre modèle sélectionne les paramètres les plus performants tout en évitant les calculs inutiles.
Cet ensemble de données, qui est un sous-ensemble de l'ensemble de données Osmanya MNIST, fournit un cas d'utilisation concret où les techniques d'entraînement de l'IA évolutives peuvent améliorer considérablement l'efficacité et la précision.
OsmanyaDataModule
ImageClassifier
ont été prédéfinis pour vous.
Cet exercice fait partie du cours
Modèles d'IA évolutifs avec PyTorch Lightning
Instructions
- Importez les rappels que vous utiliserez pour le pointage du modèle et l'arrêt anticipé.
- Entraînez le modèle à l'aide des rappels «
ModelCheckpoint
» et «EarlyStopping
».
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Import relevant checkpoints
from lightning.pytorch.callbacks import ____, ____
class EvaluatedImageClassifier(ImageClassifier):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log("val_acc", acc)
model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()
# Train the model with ModelCheckpoint and EarlyStopping checkpoints
trainer = Trainer(____=[____(monitor="val_acc", save_top_k=1), ____(monitor="val_acc", patience=3)])
trainer.fit(model, datamodule=data_module)