Loss discriminator
Saatnya mendefinisikan loss untuk discriminator. Ingat bahwa tugas discriminator adalah mengklasifikasikan citra sebagai real atau fake. Karena itu, generator menanggung loss jika ia mengklasifikasikan keluaran generator sebagai real (label 1) atau citra real sebagai fake (label 0).
Definisikan fungsi disc_loss() yang menghitung loss discriminator. Fungsi ini menerima lima argumen:
gen, model generatordisc, model discriminatorreal, sampel citra real dari data pelatihannum_images, jumlah citra dalam batchz_dim, ukuran noise acak masukan
Latihan ini adalah bagian dari kursus
Deep Learning untuk Gambar dengan PyTorch
Petunjuk latihan
- Gunakan discriminator untuk mengklasifikasikan citra
fakedan tetapkan prediksinya kedisc_pred_fake. - Hitung komponen loss untuk fake dengan memanggil
criterionpada prediksi discriminator untuk citra fake dan tensor nol dengan bentuk yang sama. - Gunakan discriminator untuk mengklasifikasikan citra
realdan tetapkan prediksinya kedisc_pred_real. - Hitung komponen loss untuk real dengan memanggil
criterionpada prediksi discriminator untuk citra real dan tensor satu dengan bentuk yang sama.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss