Perte de discriminateur
Il est temps de définir la perte pour le discriminateur. Rappelons que le rôle du discriminateur est de classer les images en deux catégories : réelles ou fausses. Par conséquent, le générateur subit une perte s'il classe les sorties du générateur comme réelles (étiquette « 1 ») ou les images réelles comme fausses (étiquette « 0 »).
Définissez la fonction d'disc_loss() qui calcule la perte du discriminateur. Il nécessite cinq arguments :
- gen, le modèle de générateur
- disc, le modèle discriminateur
- real, un échantillon d'images réelles issues des données d'entraînement
- num_images, le nombre d'images dans le lot
- z_dim, la taille du bruit aléatoire d'entrée
Cet exercice fait partie du cours
Deep learning pour les images avec PyTorch
Instructions
- Utilisez le discriminateur pour classer les images d'fakeet attribuer les prédictions àdisc_pred_fake.
- Calculez la composante de perte fictive en appelant « criterion» sur les prédictions du discriminateur pour les images fictives et le tenseur de zéros de même forme.
- Utilisez le discriminateur pour classer les images d'realet attribuer les prédictions àdisc_pred_real.
- Calculez la composante de perte réelle en appelant criterionsur les prédictions du discriminateur pour les images réelles et le tenseur de valeurs 1 de même forme.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss