MulaiMulai sekarang secara gratis

Loss generator

Sebelum Anda dapat melatih GAN, Anda perlu mendefinisikan fungsi loss untuk generator dan discriminator. Anda akan mulai dengan yang pertama.

Ingat bahwa tugas generator adalah menghasilkan gambar palsu yang cukup meyakinkan sehingga dapat menipu discriminator agar mengklasifikasikannya sebagai nyata. Oleh karena itu, generator akan menanggung loss jika gambar yang dihasilkannya diklasifikasikan oleh discriminator sebagai palsu (label 0).

Definisikan fungsi gen_loss() yang menghitung loss generator. Fungsi ini menerima empat argumen:

  • gen, model generator
  • disc, model discriminator
  • num_images, jumlah gambar dalam batch
  • z_dim, ukuran noise acak masukan

Latihan ini adalah bagian dari kursus

Deep Learning untuk Gambar dengan PyTorch

Lihat Kursus

Petunjuk latihan

  • Hasilkan noise acak berbentuk num_images kali z_dim dan tetapkan ke noise.
  • Gunakan generator untuk menghasilkan gambar palsu dari noise dan tetapkan ke fake.
  • Dapatkan prediksi discriminator untuk gambar palsu yang dihasilkan.
  • Hitung loss generator dengan memanggil criterion pada prediksi discriminator dan tensor berisi angka satu dengan bentuk yang sama.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Define random noise
    noise = ____(num_images, z_dim)
    # Generate fake image
    fake = ____
    # Get discriminator's prediction on the fake image
    disc_pred = ____
    # Compute generator loss
    criterion = nn.BCEWithLogitsLoss()
    gen_loss = ____(____, ____)
    return gen_loss
Edit dan Jalankan Kode