Perte du discriminateur
Il est temps de définir la fonction de perte du discriminateur. Rappelez-vous que le rôle du discriminateur est de classer les images comme réelles ou fictives. Par conséquent, le générateur subit une perte si le discriminateur classe les sorties du générateur comme réelles (étiquette 1) ou les images réelles comme fictives (étiquette 0).
Définissez la fonction disc_loss() qui calcule la perte du discriminateur. Elle prend cinq arguments :
gen, le modèle générateurdisc, le modèle discriminateurreal, un échantillon d’images réelles issues des données d’entraînementnum_images, le nombre d’images dans le lotz_dim, la taille du bruit aléatoire en entrée
Cet exercice fait partie du cours
Deep Learning pour l’image avec PyTorch
Instructions
- Utilisez le discriminateur pour classer les images
fakeet affectez les prédictions àdisc_pred_fake. - Calculez la composante de perte pour le faux en appelant
criterionsur les prédictions du discriminateur pour les images fictives et un tenseur de zéros de même forme. - Utilisez le discriminateur pour classer les images
realet affectez les prédictions àdisc_pred_real. - Calculez la composante de perte pour le réel en appelant
criterionsur les prédictions du discriminateur pour les images réelles et un tenseur de uns de même forme.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss