Exercise

Generator loss

Before you can train your GAN, you need to define loss functions for both the generator and the discriminator. You will start with the former.

Recall that the generator's job is to produce fake images that fool the discriminator into classifying them as real. Therefore, the generator incurs a loss whenever the discriminator classifies its generated images as fake (label 0).
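To make this concrete, here is a minimal sketch in plain Python. It assumes the loss is binary cross-entropy, which the exercise text does not state explicitly, and the helper name bce and the probabilities are made up for illustration. The generator's target for its fakes is 1 (real), so a discriminator output near 0 is penalized heavily.

```python
import math

def bce(pred, target):
    """Binary cross-entropy for a single predicted probability (illustrative helper)."""
    eps = 1e-12  # clamp to avoid log(0)
    pred = min(max(pred, eps), 1 - eps)
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

# From the generator's point of view, the desired label for a fake image is 1 (real).
loss_fooled = bce(0.9, 1.0)  # discriminator thinks the fake is probably real
loss_caught = bce(0.1, 1.0)  # discriminator thinks the fake is probably fake

print(f"fooled: {loss_fooled:.3f}, caught: {loss_caught:.3f}")
```

The loss is small when the discriminator is fooled and large when it is not, which is the signal the generator trains on.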

Define the gen_loss() function that calculates the generator loss. It takes four arguments:

  • gen, the generator model
  • disc, the discriminator model
  • num_images, the number of images in the batch
  • z_dim, the size of the input random noise

Instructions

  • Generate random noise of shape num_images by z_dim and assign it to noise.
  • Use the generator to generate a fake image from noise and assign it to fake.
  • Get the discriminator's prediction for the generated fake image.
  • Compute the generator's loss by calling criterion on the discriminator's predictions and a tensor of ones of the same shape.
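The steps above can be sketched roughly as follows. This is one possible version under stated assumptions, not the exercise's reference solution: criterion is assumed to be nn.BCEWithLogitsLoss (the exercise defines its own criterion elsewhere), the variable name disc_pred is chosen here, and toy_gen and toy_disc are hypothetical stand-in models used only to show that the function runs.

```python
import torch
import torch.nn as nn

# Assumption: binary cross-entropy on raw discriminator outputs (logits).
criterion = nn.BCEWithLogitsLoss()

def gen_loss(gen, disc, num_images, z_dim):
    # Random noise of shape (num_images, z_dim)
    noise = torch.randn(num_images, z_dim)
    # Generate fake images from the noise
    fake = gen(noise)
    # Discriminator's prediction for the fake images
    disc_pred = disc(fake)
    # The generator wants these predictions to match the "real" label (ones)
    return criterion(disc_pred, torch.ones_like(disc_pred))

# Hypothetical stand-in models, only to exercise the function
toy_gen = nn.Linear(16, 784)   # noise of size 16 -> flattened 28x28 "image"
toy_disc = nn.Linear(784, 1)   # flattened image -> one logit

loss = gen_loss(toy_gen, toy_disc, num_images=8, z_dim=16)
print(loss.item())
```

Using torch.ones_like keeps the target on the same device and in the same dtype and shape as the predictions, which avoids shape mismatches when the discriminator's output shape changes.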