MulaiMulai sekarang secara gratis

Mengimplementasikan langkah pelatihan

Dalam latihan ini, Anda akan mengimplementasikan metode training_step() dalam modul PyTorch Lightning yang dirancang untuk tugas klasifikasi citra. Implementasi Anda harus membongkar satu batch gambar dan label, menghitung prediksi model melalui forward pass, menghitung cross entropy loss, serta mencatat (log) training loss.

Latihan ini adalah bagian dari kursus

Model AI yang Dapat Diskalakan dengan PyTorch Lightning

Lihat Kursus

Petunjuk latihan

  • Pastikan Anda menghitung prediksi menggunakan forward pass.
  • Hitung cross entropy loss.
  • Catat (log) training loss.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

from torch.nn.functional import cross_entropy

def training_step(self, batch, batch_idx):
    x, y = batch
    # Ensure that you compute predictions using the forward pass
    y_hat = ____
    # Calculate the cross entropy loss
    loss = ____
    # Log the loss
    self.____("train_loss", loss)
    return loss
Edit dan Jalankan Kode