CommencerCommencer gratuitement

Perte de discriminateur

Il est temps de définir la perte pour le discriminateur. Rappelons que le rôle du discriminateur est de classer les images en deux catégories : réelles ou fausses. Par conséquent, le générateur subit une perte s'il classe les sorties du générateur comme réelles (étiquette « 1 ») ou les images réelles comme fausses (étiquette « 0 »).

Définissez la fonction d'disc_loss() qui calcule la perte du discriminateur. Il nécessite cinq arguments :

  • gen, le modèle de générateur
  • disc, le modèle discriminateur
  • real, un échantillon d'images réelles issues des données d'entraînement
  • num_images, le nombre d'images dans le lot
  • z_dim, la taille du bruit aléatoire d'entrée

Cet exercice fait partie du cours

Deep learning pour les images avec PyTorch

Afficher le cours

Instructions

  • Utilisez le discriminateur pour classer les images d'fake et attribuer les prédictions à disc_pred_fake.
  • Calculez la composante de perte fictive en appelant « criterion » sur les prédictions du discriminateur pour les images fictives et le tenseur de zéros de même forme.
  • Utilisez le discriminateur pour classer les images d'real et attribuer les prédictions à disc_pred_real.
  • Calculez la composante de perte réelle en appelant criterion sur les prédictions du discriminateur pour les images réelles et le tenseur de valeurs 1 de même forme.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
Modifier et exécuter le code