CommencerCommencer gratuitement

Perte du générateur

Avant de pouvoir entraîner votre GAN, vous devez définir des fonctions de perte pour le générateur et le discriminateur. Vous commencerez par le premier.

Rappelons que le rôle du générateur est de produire de fausses images qui trompent le discriminateur et l'amènent à les classer comme réelles. Par conséquent, le générateur subit une perte si les images qu'il a générées sont classées comme fausses par le discriminateur (étiquette 0).

Définissez la fonction d'gen_loss() qui calcule la perte du générateur. Il prend quatre arguments :

  • gen, le modèle de générateur
  • disc, le modèle discriminateur
  • num_images, le nombre d'images dans le lot
  • z_dim, la taille du bruit aléatoire d'entrée

Cet exercice fait partie du cours

Deep learning pour les images avec PyTorch

Afficher le cours

Instructions

  • Générez un bruit aléatoire de forme num_images à l'adresse z_dim et attribuez-le à noise.
  • Veuillez utiliser le générateur pour créer une image fictive à partir de noise et l'attribuer à fake.
  • Obtenez la prédiction du discriminateur pour l'image falsifiée générée.
  • Calculez la perte des générateurs en appelant « criterion » sur les prédictions du discriminateur et le tenseur de valeurs «1» de même forme.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Define random noise
    noise = ____(num_images, z_dim)
    # Generate fake image
    fake = ____
    # Get discriminator's prediction on the fake image
    disc_pred = ____
    # Compute generator loss
    criterion = nn.BCEWithLogitsLoss()
    gen_loss = ____(____, ____)
    return gen_loss
Modifier et exécuter le code