
Discriminator loss

It's time to define the loss for the discriminator. Recall that the discriminator's job is to classify images as either real or fake. Therefore, the discriminator incurs a loss if it classifies the generator's outputs as real (label 1) or the real images as fake (label 0).
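Written out, the discriminator loss is the average of two binary cross-entropy terms. In this notation, which is introduced here only for illustration, $D$ returns one raw logit per image, $\sigma$ is the sigmoid, $G$ is the generator, $z_1, \dots, z_m$ are the $m$ = `num_images` noise vectors, and $x_1, \dots, x_n$ are the $n$ real images. `BCEWithLogitsLoss` applies $\sigma$ internally and, with its default `reduction='mean'`, averages over the batch:

```latex
\mathcal{L}_D = \frac{1}{2}\left[
  -\frac{1}{m}\sum_{i=1}^{m} \log\bigl(1 - \sigma(D(G(z_i)))\bigr)
  \;-\; \frac{1}{n}\sum_{j=1}^{n} \log \sigma(D(x_j))
\right]
```

The first term is the fake component, with target label 0. The second is the real component, with target label 1.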

Define the disc_loss() function that calculates the discriminator loss. It takes five arguments:

  • gen, the generator model
  • disc, the discriminator model
  • real, a sample of real images from the training data
  • num_images, the number of images in the batch
  • z_dim, the dimension of the random noise vector fed to the generator
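To make the five arguments concrete, here is a minimal sketch of objects that could be passed in. The toy fully connected models, the sizes `Z_DIM = 16` and `IMG_DIM = 784` (flattened 28x28 images), and the random stand-in for real data are all hypothetical assumptions, not the course's models. The discriminator outputs one raw logit per image, with no final sigmoid, because `nn.BCEWithLogitsLoss` applies the sigmoid itself.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration.
Z_DIM = 16
IMG_DIM = 28 * 28  # flattened 28x28 grayscale image

# Toy generator (assumption): noise of shape (num_images, Z_DIM) -> flattened images.
gen = nn.Sequential(
    nn.Linear(Z_DIM, 128),
    nn.ReLU(),
    nn.Linear(128, IMG_DIM),
    nn.Tanh(),  # pixel values in [-1, 1]
)

# Toy discriminator (assumption): flattened images -> one raw logit per image.
# No final Sigmoid, because BCEWithLogitsLoss applies it internally.
disc = nn.Sequential(
    nn.Linear(IMG_DIM, 128),
    nn.LeakyReLU(0.2),
    nn.Linear(128, 1),
)

num_images = 8
# Random stand-in for a batch of real images, scaled to [-1, 1].
real = torch.rand(num_images, IMG_DIM) * 2 - 1

noise = torch.randn(num_images, Z_DIM)
fake = gen(noise)
print(fake.shape)        # torch.Size([8, 784])
print(disc(real).shape)  # torch.Size([8, 1])
```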

This exercise is part of the course Deep Learning for Images with PyTorch.

Exercise instructions

  • Use the discriminator to classify fake images and assign the predictions to disc_pred_fake.
  • Compute the fake loss component by calling criterion on the discriminator's predictions for fake images and a tensor of zeros of the same shape.
  • Use the discriminator to classify real images and assign the predictions to disc_pred_real.
  • Compute the real loss component by calling criterion on the discriminator's predictions for real images and a tensor of ones of the same shape.
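The "tensor of zeros/ones of the same shape" is what `torch.zeros_like` and `torch.ones_like` produce. The sketch below, using made-up logits, shows how the criterion rewards confident correct predictions and penalizes confident wrong ones. The specific logit values are illustrative assumptions.

```python
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()

# Made-up discriminator logits for three images each.
# Large negative logits mean "confidently fake"; large positive mean "confidently real".
good_pred_fake = torch.tensor([[-4.0], [-5.0], [-6.0]])  # fakes correctly called fake
bad_pred_fake = torch.tensor([[4.0], [5.0], [6.0]])      # fakes wrongly called real
good_pred_real = torch.tensor([[4.0], [5.0], [6.0]])     # reals correctly called real

# zeros_like / ones_like build targets with the same shape, dtype, and device.
good_fake_loss = criterion(good_pred_fake, torch.zeros_like(good_pred_fake))
bad_fake_loss = criterion(bad_pred_fake, torch.zeros_like(bad_pred_fake))
good_real_loss = criterion(good_pred_real, torch.ones_like(good_pred_real))

print(good_fake_loss.item())  # about 0.009
print(bad_fake_loss.item())   # about 5.01
print(good_real_loss.item())  # about 0.009
```

Note the symmetry: a logit of -4 with target 0 costs exactly as much as a logit of +4 with target 1.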

Hands-on interactive exercise

Finish this exercise by completing the sample code.

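import torch
from torch import nn

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    # Detach so this loss does not backpropagate into the generator
    fake = gen(noise).detach()
    # Get discriminator's predictions for fake images
    disc_pred_fake = disc(fake)
    # Calculate the fake loss component (target label 0 = fake)
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    # Get discriminator's predictions for real images
    disc_pred_real = disc(real)
    # Calculate the real loss component (target label 1 = real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    # Average the two components
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss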
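A minimal sketch of how a function like `disc_loss` might be used for one discriminator update. It repeats a completed version of the function so the snippet runs on its own, and it detaches the generator output, a common GAN practice assumed here. The toy models, sizes, Adam learning rate of `2e-4`, and random stand-in for real images are all hypothetical choices for illustration.

```python
import torch
from torch import nn

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise).detach()  # assumption: block gradients into the generator
    disc_pred_fake = disc(fake)
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    disc_pred_real = disc(real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    return (real_loss + fake_loss) / 2

torch.manual_seed(0)

# Hypothetical toy models and sizes, for illustration only.
z_dim, img_dim, num_images = 16, 784, 8
gen = nn.Sequential(nn.Linear(z_dim, 128), nn.ReLU(), nn.Linear(128, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1))
disc_opt = torch.optim.Adam(disc.parameters(), lr=2e-4)

# Random stand-in for a batch of real images in [-1, 1].
real = torch.rand(num_images, img_dim) * 2 - 1

# One discriminator update: only the discriminator's parameters are stepped.
disc_opt.zero_grad()
loss = disc_loss(gen, disc, real, num_images, z_dim)
loss.backward()
disc_opt.step()

print(loss.item())
# Because of .detach(), no gradient reaches the generator's parameters.
print(all(p.grad is None for p in gen.parameters()))  # True
```

Detaching keeps the discriminator step from accumulating unused gradients in the generator; the generator is then trained separately with its own loss and optimizer.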