IniziaInizia gratis

Ciclo di training

Finalmente, tutto il lavoro fatto per definire le architetture dei modelli e le funzioni di loss dà i suoi frutti: è il momento del training! Il tuo compito è implementare ed eseguire il ciclo di training della GAN. Nota: un'istruzione break è inserita dopo il primo batch di dati per evitare un tempo di esecuzione troppo lungo.

I due ottimizzatori, disc_opt e gen_opt, sono stati inizializzati come ottimizzatori Adam(). Le funzioni per calcolare le loss che hai definito prima, gen_loss() e disc_loss(), sono a tua disposizione. È stato preparato anche un dataloader.

Ricorda che:

  • Gli argomenti di disc_loss() sono: gen, disc, real, cur_batch_size, z_dim.
  • Gli argomenti di gen_loss() sono: gen, disc, cur_batch_size, z_dim.

Questo esercizio fa parte del corso

Deep Learning per Immagini con PyTorch

Visualizza il corso

Istruzioni dell'esercizio

  • Calcola la loss del discriminatore usando disc_loss() passandogli, in quest'ordine, il generatore, il discriminatore, il campione di immagini reali, la dimensione del batch corrente e la dimensione del rumore pari a 16, e assegna il risultato a d_loss.
  • Calcola i gradienti usando d_loss.
  • Calcola la loss del generatore usando gen_loss() passandogli, in quest'ordine, il generatore, il discriminatore, la dimensione del batch corrente e la dimensione del rumore pari a 16, e assegna il risultato a g_loss.
  • Calcola i gradienti usando g_loss.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Modifica ed esegui il codice