Perte du générateur
Avant d’entraîner votre GAN, vous devez définir des fonctions de perte pour le générateur et le discriminateur. Vous allez commencer par la première.
Rappelez-vous que le rôle du générateur est de produire de fausses images capables de tromper le discriminateur, afin qu’il les classe comme réelles. Par conséquent, le générateur subit une perte si les images qu’il a générées sont classées comme fausses (étiquette 0) par le discriminateur.
Définissez la fonction gen_loss() qui calcule la perte du générateur. Elle prend quatre arguments :
gen, le modèle générateurdisc, le modèle discriminateurnum_images, le nombre d’images dans le lotz_dim, la taille du bruit aléatoire en entrée
Cet exercice fait partie du cours
Deep Learning pour l’image avec PyTorch
Instructions
- Générez du bruit aléatoire de forme
num_imagesparz_dimet affectez-le ànoise. - Utilisez le générateur pour produire une fausse image à partir de
noiseet affectez-la àfake. - Obtenez la prédiction du discriminateur pour l’image factice générée.
- Calculez la perte du générateur en appelant
criterionsur les prédictions du discriminateur et un tenseur de uns de même forme.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
def gen_loss(gen, disc, criterion, num_images, z_dim):
# Define random noise
noise = ____(num_images, z_dim)
# Generate fake image
fake = ____
# Get discriminator's prediction on the fake image
disc_pred = ____
# Compute generator loss
criterion = nn.BCEWithLogitsLoss()
gen_loss = ____(____, ____)
return gen_loss