
Generator loss

Before you can train your GAN, you need to define loss functions for both the generator and the discriminator. You will start with the former.

Recall that the generator's job is to produce fake images that fool the discriminator into classifying them as real. Therefore, the generator incurs a loss when the discriminator classifies its generated images as fake (label 0).
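To make this concrete, the sketch below assumes the criterion is nn.BCEWithLogitsLoss (as in the exercise code) and feeds it a few made-up discriminator logits, all paired with the target 1 ("real"). With that target, the per-image loss reduces to -log(sigmoid(logit)). A discriminator that confidently calls a fake image "fake" (large negative logit) therefore produces a large generator loss, while a fooled discriminator (positive logit) produces a small one.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# Made-up discriminator logits for four fake images:
# large negative = "confidently fake", large positive = "confidently real".
logits = torch.tensor([[-4.0], [-1.0], [1.0], [4.0]])
targets = torch.ones_like(logits)  # the generator wants every fake labelled real (1)

# Mean loss over the batch, as the generator would see it.
loss = criterion(logits, targets)

# With target 1, binary cross-entropy is -log(sigmoid(logit)), averaged over the batch.
manual = -torch.log(torch.sigmoid(logits)).mean()

# Per-image losses: the more the discriminator is fooled, the smaller the loss.
per_image = nn.BCEWithLogitsLoss(reduction="none")(logits, targets).squeeze()
print(loss.item(), manual.item())
print(per_image)
```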

Define the gen_loss() function that calculates the generator loss. It takes five arguments:

  • gen, the generator model
  • disc, the discriminator model
  • criterion, the loss function (a binary cross-entropy loss such as nn.BCEWithLogitsLoss)
  • num_images, the number of images in the batch
  • z_dim, the size of the input random noise
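The arguments fit together through tensor shapes: gen maps a (num_images, z_dim) noise tensor to a batch of images, and disc maps that batch to one prediction per image. The sketch below illustrates that contract with hypothetical toy models (small MLPs on flattened 28x28 images, z_dim=16, num_images=8); these architectures and sizes are assumptions for illustration, not the course's models.

```python
import torch
import torch.nn as nn

z_dim = 16       # assumed noise size
num_images = 8   # assumed batch size

# Hypothetical generator: noise vector -> flattened 28x28 image in [-1, 1].
gen = nn.Sequential(
    nn.Linear(z_dim, 128),
    nn.ReLU(),
    nn.Linear(128, 784),
    nn.Tanh(),
)

# Hypothetical discriminator: flattened image -> one raw logit (no sigmoid),
# which is what nn.BCEWithLogitsLoss expects.
disc = nn.Sequential(
    nn.Linear(784, 128),
    nn.LeakyReLU(0.2),
    nn.Linear(128, 1),
)

noise = torch.randn(num_images, z_dim)  # shape (num_images, z_dim)
fake = gen(noise)                       # shape (num_images, 784)
logits = disc(fake)                     # shape (num_images, 1)
print(noise.shape, fake.shape, logits.shape)
```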

This exercise is part of the course

Deep Learning for Images with PyTorch

Exercise instructions

  • Generate random noise of shape (num_images, z_dim) and assign it to noise.
  • Use the generator to produce fake images from noise and assign them to fake.
  • Get the discriminator's predictions for the generated fake images and assign them to disc_pred.
  • Compute the generator's loss by calling criterion on the discriminator's predictions and a tensor of ones of the same shape.
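The last step needs "a tensor of ones of the same shape" as the predictions. A minimal sketch, using a random stand-in for the discriminator's output (the batch size of 8 is an assumption): torch.ones_like builds that target directly and also copies the prediction's dtype and device, which torch.ones(pred.shape) would not do automatically.

```python
import torch

# Random stand-in for the discriminator's output: one logit per image.
disc_pred = torch.randn(8, 1)

# "All real" targets with the same shape (and dtype/device) as the predictions.
targets_a = torch.ones_like(disc_pred)

# Equivalent here, but dtype and device would have to be passed explicitly
# if disc_pred lived on a GPU or used a non-default dtype.
targets_b = torch.ones(disc_pred.shape)
print(targets_a.shape, torch.equal(targets_a, targets_b))
```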

Interactive exercise

The exercise consists of completing the sample code below.

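import torch
import torch.nn as nn

def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Define random noise: one z_dim-dimensional vector per image
    noise = torch.randn(num_images, z_dim)
    # Generate fake images from the noise
    fake = gen(noise)
    # Get the discriminator's predictions (logits) on the fake images
    disc_pred = disc(fake)
    # Compute generator loss: the generator wants every fake labelled real (1).
    # criterion is supplied by the caller, e.g. nn.BCEWithLogitsLoss().
    loss = criterion(disc_pred, torch.ones_like(disc_pred))
    return loss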
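One possible way to use gen_loss() in a single generator update, sketched with hypothetical toy models: the MLP sizes, z_dim=16, num_images=8 and Adam with lr=2e-4 are all assumptions, not the course's settings. The function is repeated so the block runs on its own. Since the loss is computed through disc, loss.backward() also writes gradients into the discriminator's parameters; only the generator's optimizer steps here, so the discriminator's weights stay unchanged.

```python
import torch
import torch.nn as nn

def gen_loss(gen, disc, criterion, num_images, z_dim):
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    disc_pred = disc(fake)
    return criterion(disc_pred, torch.ones_like(disc_pred))

torch.manual_seed(0)
z_dim, num_images = 16, 8

# Hypothetical toy models (flattened 28x28 images), for illustration only.
gen = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, 784), nn.Tanh())
disc = nn.Sequential(nn.Linear(784, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))

criterion = nn.BCEWithLogitsLoss()
gen_opt = torch.optim.Adam(gen.parameters(), lr=2e-4)

# Snapshot weights so we can check which model the step actually changed.
gen_before = [p.detach().clone() for p in gen.parameters()]
disc_before = [p.detach().clone() for p in disc.parameters()]

gen_opt.zero_grad()
loss = gen_loss(gen, disc, criterion, num_images, z_dim)
loss.backward()  # gradients flow back through disc into gen
gen_opt.step()   # only the generator's parameters are updated

gen_changed = any(not torch.equal(a, b) for a, b in zip(gen_before, gen.parameters()))
disc_unchanged = all(torch.equal(a, b) for a, b in zip(disc_before, disc.parameters()))
print(loss.item(), gen_changed, disc_unchanged)
```

Because the discriminator's gradients from this step accumulate, a training loop would typically clear them (for example with the discriminator optimizer's zero_grad()) before the discriminator's own update.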