Mengimplementasikan langkah pelatihan
Dalam latihan ini, Anda akan mengimplementasikan metode training_step() dalam modul PyTorch Lightning yang dirancang untuk tugas klasifikasi citra.
Implementasi Anda harus membongkar satu batch gambar dan label, menghitung prediksi model melalui forward pass, menghitung cross entropy loss, serta mencatat (log) training loss.
Latihan ini merupakan bagian dari kursus
Model AI yang Dapat Diskalakan dengan PyTorch Lightning
Instruksi latihan
- Pastikan Anda menghitung prediksi menggunakan forward pass.
- Hitung cross entropy loss.
- Catat (log) training loss.
Latihan interaktif langsung praktik
Cobalah latihan ini dengan melengkapi kode contoh ini.
from torch.nn.functional import cross_entropy
def training_step(self, batch, batch_idx):
x, y = batch
# Ensure that you compute predictions using the forward pass
y_hat = ____
# Calculate the cross entropy loss
loss = ____
# Log the loss
self.____("train_loss", loss)
return loss