Perte de discriminateur
Il est temps de définir la perte pour le discriminateur. Rappelons que le rôle du discriminateur est de classer les images en deux catégories : réelles ou fausses. Par conséquent, le générateur subit une perte s'il classe les sorties du générateur comme réelles (étiquette « 1
») ou les images réelles comme fausses (étiquette « 0
»).
Définissez la fonction d'disc_loss()
qui calcule la perte du discriminateur. Il nécessite cinq arguments :
gen
, le modèle de générateurdisc
, le modèle discriminateurreal
, un échantillon d'images réelles issues des données d'entraînementnum_images
, le nombre d'images dans le lotz_dim
, la taille du bruit aléatoire d'entrée
Cet exercice fait partie du cours
Deep learning pour les images avec PyTorch
Instructions
- Utilisez le discriminateur pour classer les images d'
fake
et attribuer les prédictions àdisc_pred_fake
. - Calculez la composante de perte fictive en appelant «
criterion
» sur les prédictions du discriminateur pour les images fictives et le tenseur de zéros de même forme. - Utilisez le discriminateur pour classer les images d'
real
et attribuer les prédictions àdisc_pred_real
. - Calculez la composante de perte réelle en appelant
criterion
sur les prédictions du discriminateur pour les images réelles et le tenseur de valeurs 1 de même forme.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss