MulaiMulai sekarang secara gratis

Mengoptimalkan pelatihan model dengan Lightning

Dengan menerapkan teknik otomatis seperti ModelCheckpoint dan EarlyStopping, Anda akan memastikan model memilih parameter berkinerja terbaik sekaligus menghindari komputasi yang tidak perlu.

Himpunan data, sebuah subset dari himpunan data Osmanya MNIST, memberikan contoh kasus nyata di mana teknik pelatihan AI yang dapat diskalakan dapat meningkatkan efisiensi dan akurasi secara signifikan.

OsmanyaDataModule dan ImageClassifier telah disiapkan untuk Anda.

Latihan ini adalah bagian dari kursus

Model AI yang Dapat Diskalakan dengan PyTorch Lightning

Lihat Kursus

Petunjuk latihan

  • Impor callback yang akan Anda gunakan untuk penyimpanan checkpoint model dan penghentian dini.
  • Latih model dengan callback ModelCheckpoint dan EarlyStopping.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

# Import relevant checkpoints
from lightning.pytorch.callbacks import ____, ____

class EvaluatedImageClassifier(ImageClassifier):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc)

model = EvaluatedImageClassifier()
data_module = OsmanyaDataModule()
# Train the model with ModelCheckpoint and EarlyStopping checkpoints
trainer = Trainer(____=[____(monitor="val_acc", save_top_k=1), ____(monitor="val_acc", patience=3)])
trainer.fit(model, datamodule=data_module)
Edit dan Jalankan Kode