Boucle de formation
Enfin, tous les efforts que vous avez consacrés à la définition des architectures du modèle et des fonctions de perte portent leurs fruits : c'est l'heure de l'entraînement ! Votre mission consiste à mettre en œuvre et à exécuter la boucle d'entraînement GAN. Remarque : une instruction d'break
est placée après le premier lot de données afin d'éviter un temps d'exécution trop long.
Les deux optimiseurs, disc_opt
et gen_opt
, ont été initialisés en tant qu'optimiseurs d'Adam()
. Les fonctions permettant de calculer les pertes que vous avez définies précédemment, gen_loss()
et disc_loss()
, sont à votre disposition. Une déclaration de confidentialité ( dataloader
) est également mise à votre disposition.
Rappelons que :
disc_loss()
Les arguments avancés sont les suivants :gen
,disc
,real
,cur_batch_size
,z_dim
.gen_loss()
Les arguments avancés sont les suivants :gen
,disc
,cur_batch_size
,z_dim
.
Cet exercice fait partie du cours
Deep learning pour les images avec PyTorch
Instructions
- Calculez la perte du discriminateur à l'aide de la fonction «
disc_loss()
» en lui transmettant, dans cet ordre, le générateur, le discriminateur, l'échantillon d'images réelles, la taille du lot actuel et la taille du bruit de «16
», puis attribuez le résultat à «d_loss
». - Veuillez calculer les gradients à l'aide de l'
d_loss
. - Calculez la perte du générateur à l'aide de
gen_loss()
en lui transmettant le générateur, le discriminateur, la taille du lot actuel et la taille du bruit de16
, dans cet ordre, puis attribuez le résultat àg_loss
. - Veuillez calculer les gradients à l'aide de l'
g_loss
.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break