Loss generator
Sebelum Anda dapat melatih GAN, Anda perlu mendefinisikan fungsi loss untuk generator dan discriminator. Anda akan mulai dengan yang pertama.
Ingat bahwa tugas generator adalah menghasilkan gambar palsu yang cukup meyakinkan sehingga dapat menipu discriminator agar mengklasifikasikannya sebagai nyata. Oleh karena itu, generator akan menanggung loss jika gambar yang dihasilkannya diklasifikasikan oleh discriminator sebagai palsu (label 0).
Definisikan fungsi gen_loss() yang menghitung loss generator. Fungsi ini menerima empat argumen:
gen, model generatordisc, model discriminatornum_images, jumlah gambar dalam batchz_dim, ukuran noise acak masukan
Latihan ini adalah bagian dari kursus
Deep Learning untuk Gambar dengan PyTorch
Petunjuk latihan
- Hasilkan noise acak berbentuk
num_imageskaliz_dimdan tetapkan kenoise. - Gunakan generator untuk menghasilkan gambar palsu dari
noisedan tetapkan kefake. - Dapatkan prediksi discriminator untuk gambar palsu yang dihasilkan.
- Hitung loss generator dengan memanggil
criterionpada prediksi discriminator dan tensor berisi angka satu dengan bentuk yang sama.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
def gen_loss(gen, disc, criterion, num_images, z_dim):
# Define random noise
noise = ____(num_images, z_dim)
# Generate fake image
fake = ____
# Get discriminator's prediction on the fake image
disc_pred = ____
# Compute generator loss
criterion = nn.BCEWithLogitsLoss()
gen_loss = ____(____, ____)
return gen_loss