
Discriminator loss

It's time to define the loss for the discriminator. Recall that the discriminator's job is to classify images as either real or fake. Therefore, the discriminator incurs a loss if it classifies the generator's outputs as real (label 1) or the real images as fake (label 0).
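Written out as a formula, the loss looks like this. The sketch assumes that the discriminator D outputs raw logits (which is what nn.BCEWithLogitsLoss expects), that σ is the sigmoid function, that x_i are real images, and that z_i are noise vectors drawn from a standard normal distribution. It also assumes the real batch holds the same number N of images as the fake batch.

```latex
\mathcal{L}_D = \frac{1}{2}\left[
  -\frac{1}{N}\sum_{i=1}^{N} \log \sigma\big(D(x_i)\big)
  \;-\; \frac{1}{N}\sum_{i=1}^{N} \log\Big(1 - \sigma\big(D(G(z_i))\big)\Big)
\right]
```

The first term is the real loss (target 1) and the second is the fake loss (target 0). Each term is small when D is right, meaning large positive logits on real images and large negative logits on fakes. It grows when D is fooled.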

Define the disc_loss() function that calculates the discriminator loss. It takes five arguments:

  • gen, the generator model
  • disc, the discriminator model
  • real, a sample of real images from the training data
  • num_images, the number of images in the batch
  • z_dim, the size of the input random noise
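To make the arguments concrete, here is a minimal shape check. It uses hypothetical fully connected models on flattened 28×28 images; the models used in the course may be different, for example convolutional. The one firm requirement is that the discriminator returns raw logits, one per image, because nn.BCEWithLogitsLoss applies the sigmoid internally.

```python
import torch
from torch import nn

# Hypothetical toy setup: 28x28 grayscale images flattened to 784 values.
num_images, z_dim, img_dim = 16, 64, 28 * 28

# Hypothetical minimal models, for illustration only.
gen = nn.Sequential(nn.Linear(z_dim, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 1))  # one raw logit per image, no sigmoid

real = torch.rand(num_images, img_dim) * 2 - 1  # stand-in for real images scaled to [-1, 1]
noise = torch.randn(num_images, z_dim)          # one noise vector per fake image
fake = gen(noise)

print(noise.shape)       # torch.Size([16, 64])
print(fake.shape)        # torch.Size([16, 784]), same shape as real
print(disc(fake).shape)  # torch.Size([16, 1])
```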

This exercise is part of the course

Deep Learning for Images with PyTorch


Exercise instructions

  • Use the discriminator to classify fake images and assign the predictions to disc_pred_fake.
  • Compute the fake loss component by calling criterion on the discriminator's predictions for fake images and a tensor of zeros of the same shape.
  • Use the discriminator to classify real images and assign the predictions to disc_pred_real.
  • Compute the real loss component by calling criterion on the discriminator's predictions for real images and a tensor of ones of the same shape.
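The steps above can be tried on made-up logits. The sketch builds the targets with torch.zeros_like and torch.ones_like so that they always match the predictions' shape. It then checks both loss components against the binary cross-entropy formula written with log-sigmoid, using the identity 1 - sigmoid(x) = sigmoid(-x). The logit values are arbitrary and for illustration only.

```python
import torch
from torch import nn
import torch.nn.functional as F

criterion = nn.BCEWithLogitsLoss()

# Made-up discriminator logits (shape: num_images x 1), illustration only.
disc_pred_fake = torch.tensor([[-2.0], [0.5], [-1.0]])
disc_pred_real = torch.tensor([[1.5], [3.0], [-0.5]])

# Targets have the same shape as the predictions: 0 = fake, 1 = real.
fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
loss = (real_loss + fake_loss) / 2

# Same quantities by hand: BCE(x, 0) = -logsigmoid(-x), BCE(x, 1) = -logsigmoid(x),
# averaged over the batch (BCEWithLogitsLoss uses reduction="mean" by default).
manual_fake = -F.logsigmoid(-disc_pred_fake).mean()
manual_real = -F.logsigmoid(disc_pred_real).mean()
print(torch.allclose(fake_loss, manual_fake), torch.allclose(real_loss, manual_real))  # True True
```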

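Hands-on interactive exercise

A completed version of the sample code, with the four blanks filled in:

import torch
from torch import nn

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    # (detach() keeps this loss from backpropagating into the generator)
    disc_pred_fake = disc(fake.detach())
    # Calculate the fake loss component: fake images should be labeled 0
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    # Get discriminator's predictions for real images
    disc_pred_real = disc(real)
    # Calculate the real loss component: real images should be labeled 1
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    # Average the two components
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss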
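A possible usage is one discriminator update. The sketch assumes hypothetical toy models, random stand-in data and an arbitrary Adam learning rate; none of these come from the exercise. It also assumes a version of disc_loss that detaches the fake images, so only the discriminator receives gradients. The optimizer is built from disc.parameters() alone, so the step leaves the generator untouched.

```python
import torch
from torch import nn

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    disc_pred_fake = disc(fake.detach())  # detach: generator stays out of this backward pass
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    disc_pred_real = disc(real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    return (real_loss + fake_loss) / 2

torch.manual_seed(0)
num_images, z_dim, img_dim = 16, 64, 28 * 28
# Hypothetical toy models and data, for illustration only.
gen = nn.Sequential(nn.Linear(z_dim, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 1))
real = torch.rand(num_images, img_dim) * 2 - 1

# The optimizer holds only the discriminator's parameters (lr is an arbitrary choice).
disc_opt = torch.optim.Adam(disc.parameters(), lr=2e-4)

disc_opt.zero_grad()
loss = disc_loss(gen, disc, real, num_images, z_dim)
loss.backward()
disc_opt.step()

print(loss.item())
print(gen[0].weight.grad)  # None: no gradient reached the generator
```

Without the detach() call, backward() would also fill in gradients for the generator's weights. They would not be applied here, but they would cost extra computation and would need to be zeroed before the generator's own update.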