Mulai sekarangMulai gratis

Mengimplementasikan langkah pelatihan

Dalam latihan ini, Anda akan mengimplementasikan metode training_step() dalam modul PyTorch Lightning yang dirancang untuk tugas klasifikasi citra. Implementasi Anda harus membongkar satu batch gambar dan label, menghitung prediksi model melalui forward pass, menghitung cross entropy loss, serta mencatat (log) training loss.

Latihan ini merupakan bagian dari kursus

Model AI yang Dapat Diskalakan dengan PyTorch Lightning

Lihat Kursus

Instruksi latihan

  • Pastikan Anda menghitung prediksi menggunakan forward pass.
  • Hitung cross entropy loss.
  • Catat (log) training loss.

Latihan interaktif langsung praktik

Cobalah latihan ini dengan melengkapi kode contoh ini.

from torch.nn.functional import cross_entropy

def training_step(self, batch, batch_idx):
    x, y = batch
    # Ensure that you compute predictions using the forward pass
    y_hat = ____
    # Calculate the cross entropy loss
    loss = ____
    # Log the loss
    self.____("train_loss", loss)
    return loss
Edit dan Jalankan Kode