Perte du générateur
Avant de pouvoir entraîner votre GAN, vous devez définir des fonctions de perte pour le générateur et le discriminateur. Vous commencerez par le premier.
Rappelons que le rôle du générateur est de produire de fausses images qui trompent le discriminateur et l'amènent à les classer comme réelles. Par conséquent, le générateur subit une perte si les images qu'il a générées sont classées comme fausses par le discriminateur (étiquette 0).
Définissez la fonction d'gen_loss() qui calcule la perte du générateur. Il prend quatre arguments :
- gen, le modèle de générateur
- disc, le modèle discriminateur
- num_images, le nombre d'images dans le lot
- z_dim, la taille du bruit aléatoire d'entrée
Cet exercice fait partie du cours
Deep learning pour les images avec PyTorch
Instructions
- Générez un bruit aléatoire de forme num_imagesà l'adressez_dimet attribuez-le ànoise.
- Veuillez utiliser le générateur pour créer une image fictive à partir de noiseet l'attribuer àfake.
- Obtenez la prédiction du discriminateur pour l'image falsifiée générée.
- Calculez la perte des générateurs en appelant « criterion» sur les prédictions du discriminateur et le tenseur de valeurs «1» de même forme.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Define random noise
    noise = ____(num_images, z_dim)
    # Generate fake image
    fake = ____
    # Get discriminator's prediction on the fake image
    disc_pred = ____
    # Compute generator loss
    criterion = nn.BCEWithLogitsLoss()
    gen_loss = ____(____, ____)
    return gen_loss