CommencerCommencer gratuitement

Perte du discriminateur

Il est temps de définir la fonction de perte du discriminateur. Rappelez-vous que le rôle du discriminateur est de classer les images comme réelles ou fictives. Par conséquent, le générateur subit une perte si le discriminateur classe les sorties du générateur comme réelles (étiquette 1) ou les images réelles comme fictives (étiquette 0).

Définissez la fonction disc_loss() qui calcule la perte du discriminateur. Elle prend cinq arguments :

  • gen, le modèle générateur
  • disc, le modèle discriminateur
  • real, un échantillon d’images réelles issues des données d’entraînement
  • num_images, le nombre d’images dans le lot
  • z_dim, la taille du bruit aléatoire en entrée

Cet exercice fait partie du cours

Deep Learning pour l’image avec PyTorch

Afficher le cours

Instructions

  • Utilisez le discriminateur pour classer les images fake et affectez les prédictions à disc_pred_fake.
  • Calculez la composante de perte pour le faux en appelant criterion sur les prédictions du discriminateur pour les images fictives et un tenseur de zéros de même forme.
  • Utilisez le discriminateur pour classer les images real et affectez les prédictions à disc_pred_real.
  • Calculez la composante de perte pour le réel en appelant criterion sur les prédictions du discriminateur pour les images réelles et un tenseur de uns de même forme.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
Modifier et exécuter le code