Discriminator loss

It's time to define the loss for the discriminator. Recall that the discriminator's job is to classify images as either real or fake. Therefore, the discriminator incurs a loss if it classifies the generator's outputs as real (label 1) or the real images as fake (label 0).
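To make the idea of "incurring a loss" concrete, the short sketch below compares the binary cross-entropy loss for a discriminator that is fooled by fake images with one that classifies them correctly. The logit values and batch size are hypothetical, chosen only for illustration; the loss class is the same nn.BCEWithLogitsLoss used in the exercise code.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# Hypothetical raw discriminator outputs (logits) for a batch of 4 fake images.
# A large positive logit means "confidently real"; a large negative one means "confidently fake".
fooled_logits = torch.full((4, 1), 3.0)    # discriminator wrongly leans towards "real"
correct_logits = torch.full((4, 1), -3.0)  # discriminator correctly leans towards "fake"

# Fake images carry the target label 0.
fake_targets = torch.zeros(4, 1)

loss_when_fooled = criterion(fooled_logits, fake_targets)
loss_when_correct = criterion(correct_logits, fake_targets)

# The fooled discriminator should be penalized more heavily.
print(loss_when_fooled.item(), loss_when_correct.item())
```

The same comparison applies in the other direction for real images, where the target is a tensor of ones.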

Define the disc_loss() function that calculates the discriminator loss. It takes five arguments:

  • gen, the generator model
  • disc, the discriminator model
  • real, a sample of real images from the training data
  • num_images, the number of images in the batch
  • z_dim, the size of the input random noise

This exercise is part of the course

Deep Learning for Images with PyTorch

Exercise instructions

  • Use the discriminator to classify fake images and assign the predictions to disc_pred_fake.
  • Compute the fake loss component by calling criterion on the discriminator's predictions for fake images and a tensor of zeros of the same shape.
  • Use the discriminator to classify real images and assign the predictions to disc_pred_real.
  • Compute the real loss component by calling criterion on the discriminator's predictions for real images and a tensor of ones of the same shape.

Hands-on interactive exercise

Try this exercise by completing the sample code below.

import torch
import torch.nn as nn

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
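One possible way to fill in the blanks is sketched below, following the four instructions literally. The toy gen and disc models (single linear layers) and the dimensions used in the usage lines are assumptions made only so the sketch runs on its own; they are not the course's models. torch.zeros_like and torch.ones_like are used here to build the target tensors with the same shape as the predictions.

```python
import torch
import torch.nn as nn

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = disc(fake)
    # Fake images should be labelled 0
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    # Get discriminator's predictions for real images
    disc_pred_real = disc(real)
    # Real images should be labelled 1
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    # Average the two components
    return (real_loss + fake_loss) / 2

# Usage with hypothetical stand-in models and flattened 64-value "images".
z_dim, img_dim, num_images = 16, 64, 8
gen = nn.Linear(z_dim, img_dim)   # toy generator: noise -> image
disc = nn.Linear(img_dim, 1)      # toy discriminator: image -> one logit
real = torch.randn(num_images, img_dim)

loss = disc_loss(gen, disc, real, num_images, z_dim)
print(loss.item())
```

In a full training loop it is common to call fake.detach() before passing the fake images to the discriminator, so that the discriminator update does not backpropagate into the generator; the sketch omits this to stay close to the exercise instructions.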