Optimización del entrenamiento de modelos con Lightning
Al implementar técnicas automatizadas como ModelCheckpoint
y EarlyStopping
, te asegurarás de que tu modelo seleccione los parámetros con mejor rendimiento y evites cálculos innecesarios.
El conjunto de datos, un subconjunto del conjunto de datos Osmanya MNIST, proporciona un caso de uso real en el que las técnicas de entrenamiento de IA escalables pueden mejorar significativamente la eficiencia y la precisión.
OsmanyaDataModule
ImageClassifier
han sido predefinidos para ti.
Este ejercicio forma parte del curso
Modelos de IA escalables con PyTorch Lightning
Instrucciones del ejercicio
- Importa las devoluciones de llamada que utilizarás para el control de puntos de verificación del modelo y la detención temprana.
- Entrena el modelo con las funciones de devolución de llamada «
ModelCheckpoint
» y «EarlyStopping
».
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
# Import relevant checkpoints
from lightning.pytorch.callbacks import ____, ____
class EvaluatedImageClassifier(ImageClassifier):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log("val_acc", acc)
model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()
# Train the model with ModelCheckpoint and EarlyStopping checkpoints
trainer = Trainer(____=[____(monitor="val_acc", save_top_k=1), ____(monitor="val_acc", patience=3)])
trainer.fit(model, datamodule=data_module)