Boucle d’entraînement
Enfin, tout le travail que vous avez consacré à définir les architectures de modèles et les fonctions de perte porte ses fruits : c’est le moment d’entraîner le modèle ! Votre tâche est d’implémenter et d’exécuter la boucle d’entraînement du GAN. Remarque : une instruction break est placée après le premier lot de données pour éviter un temps d’exécution trop long.
Les deux optimiseurs, disc_opt et gen_opt, ont été initialisés comme optimiseurs Adam(). Les fonctions de calcul des pertes que vous avez définies plus tôt, gen_loss() et disc_loss(), sont à votre disposition. Un dataloader est également prêt pour vous.
Rappelez-vous que :
- Les arguments de
disc_loss()sont :gen,disc,real,cur_batch_size,z_dim. - Les arguments de
gen_loss()sont :gen,disc,cur_batch_size,z_dim.
Cet exercice fait partie du cours
Deep Learning pour l’image avec PyTorch
Instructions
- Calculez la perte du discriminateur avec
disc_loss()en lui passant, dans cet ordre, le générateur, le discriminateur, l’échantillon d’images réelles, la taille de lot courante et la taille du bruit de16, puis affectez le résultat àd_loss. - Calculez les gradients en utilisant
d_loss. - Calculez la perte du générateur avec
gen_loss()en lui passant, dans cet ordre, le générateur, le discriminateur, la taille de lot courante et la taille du bruit de16, puis affectez le résultat àg_loss. - Calculez les gradients en utilisant
g_loss.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break