Generator loss
Before you can train your GAN, you need to define loss functions for both the generator and the discriminator. You will start with the former.
Recall that the generator's job is to produce fake images that fool the discriminator into classifying them as real. Therefore, the generator incurs a loss whenever the discriminator classifies its generated images as fake (label 0).
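Written as an equation, this objective looks roughly as follows. This is a sketch that assumes the discriminator D returns raw logits and that the loss is binary cross-entropy on logits averaged over the batch (PyTorch's nn.BCEWithLogitsLoss with its default mean reduction). For N noise vectors z_i drawn from a standard normal distribution, the generator is scored against the "real" label 1:

```latex
$$
\mathcal{L}_G = -\frac{1}{N}\sum_{i=1}^{N} \log \sigma\big(D(G(z_i))\big),
\qquad z_i \sim \mathcal{N}(0, I)
$$
```

Here σ is the sigmoid function. The loss approaches 0 when D assigns the fakes a high probability of being real, and it grows without bound as that probability approaches 0.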
Define the gen_loss() function that calculates the generator loss. It takes five arguments:
- gen, the generator model
- disc, the discriminator model
- criterion, the loss function
- num_images, the number of images in the batch
- z_dim, the size of the input random noise
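To make the shapes behind these arguments concrete, here is a minimal sketch with hypothetical toy models. The layer sizes, Z_DIM, IMG_DIM and the fully connected architecture are assumptions for illustration only, not the course's models. The only properties that matter for gen_loss() are that gen maps a (num_images, z_dim) noise batch to images and that disc returns one raw logit per image.

```python
import torch
import torch.nn as nn

# Hypothetical toy models, assumed for illustration only (not the course's models).
Z_DIM = 16          # assumed noise size
IMG_DIM = 28 * 28   # assumed flattened image size

# Generator: maps a (num_images, Z_DIM) noise batch to (num_images, IMG_DIM) images.
gen = nn.Sequential(
    nn.Linear(Z_DIM, 128),
    nn.ReLU(),
    nn.Linear(128, IMG_DIM),
    nn.Tanh(),  # pixel values in [-1, 1]
)

# Discriminator: maps each image to a single raw logit (no sigmoid),
# which is what nn.BCEWithLogitsLoss expects.
disc = nn.Sequential(
    nn.Linear(IMG_DIM, 128),
    nn.LeakyReLU(0.2),
    nn.Linear(128, 1),
)

num_images = 8
noise = torch.randn(num_images, Z_DIM)  # one noise vector per image
fake = gen(noise)
logits = disc(fake)
print(fake.shape)    # torch.Size([8, 784])
print(logits.shape)  # torch.Size([8, 1])
```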
This exercise is part of the course Deep Learning for Images with PyTorch.
Exercise instructions
- Generate random noise of shape num_images by z_dim and assign it to noise.
- Use the generator to generate fake images from noise and assign them to fake.
- Get the discriminator's prediction for the generated fake images.
- Compute the generator's loss by calling criterion on the discriminator's predictions and a tensor of ones of the same shape.
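Why compare against a tensor of ones? The sketch below feeds a few hypothetical discriminator logits into nn.BCEWithLogitsLoss with an all-ones target, which is the "real" label the generator is aiming for. The logit values are made up for illustration. An undecided discriminator (logit 0, sigmoid 0.5) gives a loss of log 2, and the loss rises as the discriminator grows more confident that the images are fake.

```python
import torch
import torch.nn as nn

# Applies a sigmoid internally, then binary cross-entropy (mean over the batch)
criterion = nn.BCEWithLogitsLoss()

# Hypothetical discriminator logits for a batch of three fake images
cases = {
    "real": torch.tensor([[4.0], [4.0], [4.0]]),         # D is fooled: says "real"
    "undecided": torch.tensor([[0.0], [0.0], [0.0]]),    # sigmoid(0) = 0.5
    "fake": torch.tensor([[-4.0], [-4.0], [-4.0]]),      # D correctly says "fake"
}

losses = {}
for name, logits in cases.items():
    target = torch.ones_like(logits)  # generator's target: label 1 ("real")
    losses[name] = criterion(logits, target).item()
    print(name, losses[name])
```

The generator therefore minimizes this loss by pushing the discriminator's logits on fake images upward.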
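import torch
import torch.nn as nn

def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Define random noise: one z_dim-sized standard-normal vector per image
    noise = torch.randn(num_images, z_dim)
    # Generate fake images from the noise
    fake = gen(noise)
    # Get discriminator's prediction (raw logits) on the fake images
    disc_pred = disc(fake)
    # Compute generator loss against "real" labels (ones);
    # criterion is expected to be nn.BCEWithLogitsLoss(), passed in by the caller
    gen_loss = criterion(disc_pred, torch.ones_like(disc_pred))
    return gen_loss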
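The function can be used in a single generator update step roughly as follows. This is a self-contained sketch: the toy models, z_dim, img_dim, batch size, Adam optimizer and learning rate are all hypothetical choices, not taken from the course.

```python
import torch
import torch.nn as nn

def gen_loss(gen, disc, criterion, num_images, z_dim):
    noise = torch.randn(num_images, z_dim)                       # random noise batch
    fake = gen(noise)                                            # fake images
    disc_pred = disc(fake)                                       # discriminator logits
    return criterion(disc_pred, torch.ones_like(disc_pred))      # target: "real"

# Hypothetical toy models and hyperparameters (assumptions, not from the course)
z_dim, img_dim, num_images = 16, 784, 8
gen = nn.Sequential(nn.Linear(z_dim, 128), nn.ReLU(), nn.Linear(128, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1))

criterion = nn.BCEWithLogitsLoss()
gen_opt = torch.optim.Adam(gen.parameters(), lr=2e-4)

# One generator update: only the generator's optimizer takes a step.
gen_opt.zero_grad()
loss = gen_loss(gen, disc, criterion, num_images, z_dim)
loss.backward()   # gradients flow back through disc into gen
gen_opt.step()    # disc weights stay unchanged because its optimizer is not stepped
print(loss.item())
```

Note that backward() also fills in gradients on the discriminator's parameters, so a discriminator update step should call zero_grad() on its own optimizer before computing its loss.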