
Generator loss

Before you can train your GAN, you need to define loss functions for both the generator and the discriminator. You will start with the former.

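Recall that the generator's job is to produce fake images that fool the discriminator into classifying them as real. Therefore, the generator incurs a loss when the discriminator classifies its generated images as fake (label 0). In practice, this loss is computed by comparing the discriminator's predictions on the fake images against the "real" label (1): the further the predictions are from 1, the larger the loss.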
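To make this concrete, here is a small sketch that assumes the discriminator outputs raw logits and that the loss is nn.BCEWithLogitsLoss (the criterion used in the sample code below). The three logit values are made up for illustration. With a target of 1, the per-image loss reduces to -log(sigmoid(logit)), so it is near zero when the discriminator is fooled and large when it confidently spots the fake.

```python
import torch
import torch.nn as nn

# Hypothetical raw discriminator outputs (logits) for three fake images:
# judged confidently real, undecided, and confidently fake.
logits = torch.tensor([[4.0], [0.0], [-4.0]])

# reduction="none" keeps one loss value per image instead of averaging.
criterion = nn.BCEWithLogitsLoss(reduction="none")

# The generator wants every fake judged real, so the target is all ones.
per_image = criterion(logits, torch.ones_like(logits))

# Closed form of BCE with target 1: -log(sigmoid(logit)).
manual = -torch.log(torch.sigmoid(logits))

print(per_image.squeeze())  # approximately 0.018, 0.693, 4.018
```

The loss grows as the discriminator becomes more confident that the image is fake, which is exactly the signal the generator learns from.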

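Define the gen_loss() function that calculates the generator loss. It takes five arguments:

  • gen, the generator model
  • disc, the discriminator model
  • criterion, the loss function used to compare the discriminator's predictions with the target labels
  • num_images, the number of images in the batch
  • z_dim, the dimensionality of the input random noise vector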

This exercise is part of the course

Deep Learning for Images with PyTorch


Exercise instructions

  • Generate random noise of shape (num_images, z_dim) and assign it to noise.
  • Use the generator to generate fake images from noise and assign them to fake.
  • Get the discriminator's predictions for the generated fake images.
  • Compute the generator's loss by calling criterion on the discriminator's predictions and a tensor of ones of the same shape.

Sample code for gen_loss():

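import torch
import torch.nn as nn

def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Define random noise: one z_dim-dimensional vector per image
    noise = torch.randn(num_images, z_dim)
    # Generate fake images from the noise
    fake = gen(noise)
    # Get the discriminator's prediction on the fake images
    disc_pred = disc(fake)
    # Compute generator loss: the generator wants its fakes labeled real (1).
    # criterion is expected to be nn.BCEWithLogitsLoss(), which takes raw logits.
    loss = criterion(disc_pred, torch.ones_like(disc_pred))
    return loss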
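The sketch below shows how gen_loss() might be used in one generator update step. The tiny fully connected generator and discriminator, the flattened 28 × 28 image size, z_dim = 16, and the Adam learning rate are all hypothetical choices for illustration, not the course's architecture. The discriminator ends in a plain nn.Linear layer, so it outputs logits as nn.BCEWithLogitsLoss expects.

```python
import torch
import torch.nn as nn

def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Same logic as the function above: noise -> fakes -> predictions -> loss vs. ones
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    disc_pred = disc(fake)
    return criterion(disc_pred, torch.ones_like(disc_pred))

torch.manual_seed(0)
z_dim, num_images = 16, 8

# Hypothetical toy models working on flattened 28x28 images.
gen = nn.Sequential(
    nn.Linear(z_dim, 64), nn.ReLU(),
    nn.Linear(64, 28 * 28), nn.Tanh(),
)
disc = nn.Sequential(
    nn.Linear(28 * 28, 64), nn.LeakyReLU(0.2),
    nn.Linear(64, 1),  # raw logit, no sigmoid
)

criterion = nn.BCEWithLogitsLoss()
gen_opt = torch.optim.Adam(gen.parameters(), lr=2e-4)

# One generator update: only the generator's optimizer steps.
gen_opt.zero_grad()
loss = gen_loss(gen, disc, criterion, num_images, z_dim)
loss.backward()
gen_opt.step()
print(loss.item())
```

Because loss.backward() also fills in gradients for the discriminator's parameters, a full training loop would zero the discriminator's gradients before its own update, so that the generator step does not leak into it.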