CommencerCommencer gratuitement

Boucle de formation

Enfin, tous les efforts que vous avez consacrés à la définition des architectures du modèle et des fonctions de perte portent leurs fruits : c'est l'heure de l'entraînement ! Votre mission consiste à mettre en œuvre et à exécuter la boucle d'entraînement GAN. Remarque : une instruction d'break est placée après le premier lot de données afin d'éviter un temps d'exécution trop long.

Les deux optimiseurs, disc_opt et gen_opt, ont été initialisés en tant qu'optimiseurs d'Adam(). Les fonctions permettant de calculer les pertes que vous avez définies précédemment, gen_loss() et disc_loss(), sont à votre disposition. Une déclaration de confidentialité ( dataloader ) est également mise à votre disposition.

Rappelons que :

  • disc_loss()Les arguments avancés sont les suivants : gen, disc, real, cur_batch_size, z_dim.
  • gen_loss()Les arguments avancés sont les suivants : gen, disc, cur_batch_size, z_dim.

Cet exercice fait partie du cours

Deep learning pour les images avec PyTorch

Afficher le cours

Instructions

  • Calculez la perte du discriminateur à l'aide de la fonction « disc_loss() » en lui transmettant, dans cet ordre, le générateur, le discriminateur, l'échantillon d'images réelles, la taille du lot actuel et la taille du bruit de « 16 », puis attribuez le résultat à « d_loss ».
  • Veuillez calculer les gradients à l'aide de l'd_loss.
  • Calculez la perte du générateur à l'aide de gen_loss() en lui transmettant le générateur, le discriminateur, la taille du lot actuel et la taille du bruit de 16, dans cet ordre, puis attribuez le résultat à g_loss.
  • Veuillez calculer les gradients à l'aide de l'g_loss.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Modifier et exécuter le code