Mengimplementasikan langkah pelatihan
Dalam latihan ini, Anda akan mengimplementasikan metode training_step() dalam modul PyTorch Lightning yang dirancang untuk tugas klasifikasi citra.
Implementasi Anda harus membongkar satu batch gambar dan label, menghitung prediksi model melalui forward pass, menghitung cross entropy loss, serta mencatat (log) training loss.
Latihan ini adalah bagian dari kursus
Model AI yang Dapat Diskalakan dengan PyTorch Lightning
Petunjuk latihan
- Pastikan Anda menghitung prediksi menggunakan forward pass.
- Hitung cross entropy loss.
- Catat (log) training loss.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
from torch.nn.functional import cross_entropy
def training_step(self, batch, batch_idx):
x, y = batch
# Ensure that you compute predictions using the forward pass
y_hat = ____
# Calculate the cross entropy loss
loss = ____
# Log the loss
self.____("train_loss", loss)
return loss