Perte du générateur
Avant de pouvoir entraîner votre GAN, vous devez définir des fonctions de perte pour le générateur et le discriminateur. Vous commencerez par le premier.
Rappelons que le rôle du générateur est de produire de fausses images qui trompent le discriminateur et l'amènent à les classer comme réelles. Par conséquent, le générateur subit une perte si les images qu'il a générées sont classées comme fausses par le discriminateur (étiquette 0
).
Définissez la fonction d'gen_loss()
qui calcule la perte du générateur. Il prend quatre arguments :
gen
, le modèle de générateurdisc
, le modèle discriminateurnum_images
, le nombre d'images dans le lotz_dim
, la taille du bruit aléatoire d'entrée
Cet exercice fait partie du cours
Deep learning pour les images avec PyTorch
Instructions
- Générez un bruit aléatoire de forme
num_images
à l'adressez_dim
et attribuez-le ànoise
. - Veuillez utiliser le générateur pour créer une image fictive à partir de
noise
et l'attribuer àfake
. - Obtenez la prédiction du discriminateur pour l'image falsifiée générée.
- Calculez la perte des générateurs en appelant «
criterion
» sur les prédictions du discriminateur et le tenseur de valeurs «1» de même forme.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
def gen_loss(gen, disc, criterion, num_images, z_dim):
# Define random noise
noise = ____(num_images, z_dim)
# Generate fake image
fake = ____
# Get discriminator's prediction on the fake image
disc_pred = ____
# Compute generator loss
criterion = nn.BCEWithLogitsLoss()
gen_loss = ____(____, ____)
return gen_loss