CommencerCommencer gratuitement

Perte du générateur

Avant d’entraîner votre GAN, vous devez définir des fonctions de perte pour le générateur et le discriminateur. Vous allez commencer par la première.

Rappelez-vous que le rôle du générateur est de produire de fausses images capables de tromper le discriminateur, afin qu’il les classe comme réelles. Par conséquent, le générateur subit une perte si les images qu’il a générées sont classées comme fausses (étiquette 0) par le discriminateur.

Définissez la fonction gen_loss() qui calcule la perte du générateur. Elle prend quatre arguments :

  • gen, le modèle générateur
  • disc, le modèle discriminateur
  • num_images, le nombre d’images dans le lot
  • z_dim, la taille du bruit aléatoire en entrée

Cet exercice fait partie du cours

Deep Learning pour l’image avec PyTorch

Afficher le cours

Instructions

  • Générez du bruit aléatoire de forme num_images par z_dim et affectez-le à noise.
  • Utilisez le générateur pour produire une fausse image à partir de noise et affectez-la à fake.
  • Obtenez la prédiction du discriminateur pour l’image factice générée.
  • Calculez la perte du générateur en appelant criterion sur les prédictions du discriminateur et un tenseur de uns de même forme.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Define random noise
    noise = ____(num_images, z_dim)
    # Generate fake image
    fake = ____
    # Get discriminator's prediction on the fake image
    disc_pred = ____
    # Compute generator loss
    criterion = nn.BCEWithLogitsLoss()
    gen_loss = ____(____, ____)
    return gen_loss
Modifier et exécuter le code