MulaiMulai sekarang secara gratis

Loss discriminator

Saatnya mendefinisikan loss untuk discriminator. Ingat bahwa tugas discriminator adalah mengklasifikasikan citra sebagai real atau fake. Karena itu, generator menanggung loss jika ia mengklasifikasikan keluaran generator sebagai real (label 1) atau citra real sebagai fake (label 0).

Definisikan fungsi disc_loss() yang menghitung loss discriminator. Fungsi ini menerima lima argumen:

  • gen, model generator
  • disc, model discriminator
  • real, sampel citra real dari data pelatihan
  • num_images, jumlah citra dalam batch
  • z_dim, ukuran noise acak masukan

Latihan ini adalah bagian dari kursus

Deep Learning untuk Gambar dengan PyTorch

Lihat Kursus

Petunjuk latihan

  • Gunakan discriminator untuk mengklasifikasikan citra fake dan tetapkan prediksinya ke disc_pred_fake.
  • Hitung komponen loss untuk fake dengan memanggil criterion pada prediksi discriminator untuk citra fake dan tensor nol dengan bentuk yang sama.
  • Gunakan discriminator untuk mengklasifikasikan citra real dan tetapkan prediksinya ke disc_pred_real.
  • Hitung komponen loss untuk real dengan memanggil criterion pada prediksi discriminator untuk citra real dan tensor satu dengan bentuk yang sama.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
Edit dan Jalankan Kode