Otimizando o treinamento do modelo com o Lightning
Ao usar técnicas automáticas como ModelCheckpoint
e EarlyStopping
, você vai garantir que seu modelo escolha os parâmetros que funcionam melhor, sem precisar fazer cálculos desnecessários.
O conjunto de dados, que é uma parte do conjunto de dados Osmanya MNIST, mostra um caso real em que técnicas de treinamento de IA que podem ser usadas em grande escala melhoram bastante a eficiência e a precisão.
OsmanyaDataModule
e ImageClassifier
já estão definidos pra você.
Este exercício faz parte do curso
Modelos de IA escaláveis com PyTorch Lightning
Instruções do exercício
- Importa os callbacks que você vai usar para verificar pontos de verificação do modelo e interrupção antecipada.
- Treine o modelo com as chamadas de retorno “
ModelCheckpoint
” e “EarlyStopping
”.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Import relevant checkpoints
from lightning.pytorch.callbacks import ____, ____
class EvaluatedImageClassifier(ImageClassifier):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log("val_acc", acc)
model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()
# Train the model with ModelCheckpoint and EarlyStopping checkpoints
trainer = Trainer(____=[____(monitor="val_acc", save_top_k=1), ____(monitor="val_acc", patience=3)])
trainer.fit(model, datamodule=data_module)