CommencerCommencer gratuitement

Boucle d’entraînement

Enfin, tout le travail que vous avez consacré à définir les architectures de modèles et les fonctions de perte porte ses fruits : c’est le moment d’entraîner le modèle ! Votre tâche est d’implémenter et d’exécuter la boucle d’entraînement du GAN. Remarque : une instruction break est placée après le premier lot de données pour éviter un temps d’exécution trop long.

Les deux optimiseurs, disc_opt et gen_opt, ont été initialisés comme optimiseurs Adam(). Les fonctions de calcul des pertes que vous avez définies plus tôt, gen_loss() et disc_loss(), sont à votre disposition. Un dataloader est également prêt pour vous.

Rappelez-vous que :

  • Les arguments de disc_loss() sont : gen, disc, real, cur_batch_size, z_dim.
  • Les arguments de gen_loss() sont : gen, disc, cur_batch_size, z_dim.

Cet exercice fait partie du cours

Deep Learning pour l’image avec PyTorch

Afficher le cours

Instructions

  • Calculez la perte du discriminateur avec disc_loss() en lui passant, dans cet ordre, le générateur, le discriminateur, l’échantillon d’images réelles, la taille de lot courante et la taille du bruit de 16, puis affectez le résultat à d_loss.
  • Calculez les gradients en utilisant d_loss.
  • Calculez la perte du générateur avec gen_loss() en lui passant, dans cet ordre, le générateur, le discriminateur, la taille de lot courante et la taille du bruit de 16, puis affectez le résultat à g_loss.
  • Calculez les gradients en utilisant g_loss.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Modifier et exécuter le code