Mengoptimalkan pelatihan model dengan Lightning
Dengan menerapkan teknik otomatis seperti ModelCheckpoint dan EarlyStopping, Anda akan memastikan model memilih parameter berkinerja terbaik sekaligus menghindari komputasi yang tidak perlu.
Himpunan data, sebuah subset dari himpunan data Osmanya MNIST, memberikan contoh kasus nyata di mana teknik pelatihan AI yang dapat diskalakan dapat meningkatkan efisiensi dan akurasi secara signifikan.
OsmanyaDataModule dan ImageClassifier telah disiapkan untuk Anda.
Latihan ini adalah bagian dari kursus
Model AI yang Dapat Diskalakan dengan PyTorch Lightning
Petunjuk latihan
- Impor callback yang akan Anda gunakan untuk penyimpanan checkpoint model dan penghentian dini.
- Latih model dengan callback
ModelCheckpointdanEarlyStopping.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
# Import relevant checkpoints
from lightning.pytorch.callbacks import ____, ____
class EvaluatedImageClassifier(ImageClassifier):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log("val_acc", acc)
model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()
# Train the model with ModelCheckpoint and EarlyStopping checkpoints
trainer = Trainer(____=[____(monitor="val_acc", save_top_k=1), ____(monitor="val_acc", patience=3)])
trainer.fit(model, datamodule=data_module)