Discriminator loss
It's time to define the loss for the discriminator. Recall that the discriminator's job is to classify images as either real or fake. Therefore, the discriminator incurs a loss if it classifies the generator's outputs as real (label 1) or the real images as fake (label 0).
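The following minimal sketch shows how nn.BCEWithLogitsLoss turns one raw discriminator output (a logit) into a loss for each possible label. The logit value 4.0 is an arbitrary illustrative choice. A confident "real" prediction costs little when the label is 1 and a lot when the label is 0. The manual formulas confirm that the loss applies the sigmoid internally.

```python
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()

# One raw discriminator output (logit); a large positive value means the
# discriminator is confident the image is real. The value is illustrative.
logit = torch.tensor([4.0])
real_label = torch.ones_like(logit)   # label 1 = real
fake_label = torch.zeros_like(logit)  # label 0 = fake

# If the image really is real, the prediction is correct and the loss is small.
loss_correct = criterion(logit, real_label)
# If the image is actually fake, the same prediction is wrong and the loss is large.
loss_wrong = criterion(logit, fake_label)

# BCEWithLogitsLoss applies the sigmoid itself:
#   label 1 -> -log(sigmoid(x)),  label 0 -> -log(1 - sigmoid(x))
manual_correct = -torch.log(torch.sigmoid(logit)).mean()
manual_wrong = -torch.log(1 - torch.sigmoid(logit)).mean()

print(f"correct: {loss_correct.item():.4f}  wrong: {loss_wrong.item():.4f}")
```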
Define the disc_loss() function that calculates the discriminator loss. It takes five arguments:
- gen, the generator model
- disc, the discriminator model
- real, a sample of real images from the training data
- num_images, the number of images in the batch
- z_dim, the size of the input random noise
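The arguments fit together through their tensor shapes. The sketch below uses hypothetical stand-ins for gen and disc: small fully connected networks over flattened 28x28 images. These are assumptions, not the course's own architectures. It traces how noise of shape (num_images, z_dim) becomes fake images and then one logit per image.

```python
import torch
from torch import nn

# Hypothetical sizes for illustration only.
z_dim, img_dim, num_images = 16, 28 * 28, 8

# Hypothetical stand-ins for gen and disc: any modules with these
# input/output shapes would fit the disc_loss() signature.
gen = nn.Sequential(nn.Linear(z_dim, 128), nn.ReLU(), nn.Linear(128, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1))

noise = torch.randn(num_images, z_dim)          # (num_images, z_dim)
fake = gen(noise)                               # (num_images, img_dim), values in [-1, 1]
real = torch.rand(num_images, img_dim) * 2 - 1  # stand-in for a batch of real images in [-1, 1]
logits = disc(fake)                             # (num_images, 1): one raw score per image

print(noise.shape, fake.shape, logits.shape)
```

Because disc returns raw logits rather than probabilities, it pairs naturally with nn.BCEWithLogitsLoss, which expects unnormalized scores.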
This exercise is part of the course Deep Learning for Images with PyTorch.
Exercise instructions
- Use the discriminator to classify fake images and assign the predictions to disc_pred_fake.
- Compute the fake loss component by calling criterion on the discriminator's predictions for fake images and a tensor of zeros of the same shape.
- Use the discriminator to classify real images and assign the predictions to disc_pred_real.
- Compute the real loss component by calling criterion on the discriminator's predictions for real images and a tensor of ones of the same shape.
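As a sanity check on these steps, consider the hypothetical helper combined_loss below. It is not part of the exercise. It builds the zero and one targets with torch.zeros_like and torch.ones_like and averages the two components. A discriminator that outputs logit 0 (probability 0.5) for every image cannot tell the classes apart and scores exactly ln 2. A confidently correct one scores close to 0.

```python
import math
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()

def combined_loss(pred_real, pred_fake):
    # Hypothetical helper. zeros_like/ones_like create targets that always
    # match the predictions' shape, dtype, and device.
    real_loss = criterion(pred_real, torch.ones_like(pred_real))
    fake_loss = criterion(pred_fake, torch.zeros_like(pred_fake))
    return (real_loss + fake_loss) / 2

# An undecided discriminator outputs logit 0 (probability 0.5) for every image,
# so each component is -log(0.5) = ln 2, and so is their average.
undecided = torch.zeros(8, 1)
print(combined_loss(undecided, undecided).item(), math.log(2))

# A discriminator that is confidently right on both batches scores near 0.
print(combined_loss(torch.full((8, 1), 10.0), torch.full((8, 1), -10.0)).item())
```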
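The sample code, completed according to the instructions above:
import torch
from torch import nn

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    # Sample random noise and turn it into a batch of fake images
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = disc(fake)
    # Calculate the fake loss component: fake images should be labeled 0
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    # Get discriminator's predictions for real images
    disc_pred_real = disc(real)
    # Calculate the real loss component: real images should be labeled 1
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    # Average the two components
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss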
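A loss like this is typically minimized in a discriminator training step. The minimal sketch below assumes hypothetical tiny models, an Adam optimizer with lr=2e-4, and a helper named discriminator_step, none of which come from the exercise. It passes fake.detach() to the discriminator. This is a common GAN practice that the exercise does not ask for: since only the discriminator is updated here, detaching stops gradients from flowing back into the generator.

```python
import torch
from torch import nn

torch.manual_seed(0)
z_dim, img_dim, batch = 16, 64, 32

# Hypothetical tiny models; the course's own architectures may differ.
gen = nn.Sequential(nn.Linear(z_dim, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
disc_opt = torch.optim.Adam(disc.parameters(), lr=2e-4)
criterion = nn.BCEWithLogitsLoss()

def discriminator_step(real):
    # Hypothetical helper: one optimization step for the discriminator only.
    disc_opt.zero_grad()
    fake = gen(torch.randn(real.size(0), z_dim))
    # detach(): this step updates only the discriminator, so no gradients
    # are computed for the generator's parameters.
    pred_fake = disc(fake.detach())
    pred_real = disc(real)
    loss = (criterion(pred_real, torch.ones_like(pred_real))
            + criterion(pred_fake, torch.zeros_like(pred_fake))) / 2
    loss.backward()
    disc_opt.step()
    return loss.item()

real_batch = torch.rand(batch, img_dim) * 2 - 1  # stand-in for real images in [-1, 1]
for step in range(3):
    print(step, discriminator_step(real_batch))
```

Without the detach, the same loss value results, but backward() also accumulates gradients in the generator's parameters, which the generator's own step would then have to clear.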