Ciclo di training
Finalmente, tutto il lavoro fatto per definire le architetture dei modelli e le funzioni di loss dà i suoi frutti: è il momento del training! Il tuo compito è implementare ed eseguire il ciclo di training della GAN. Nota: un'istruzione break è inserita dopo il primo batch di dati per evitare un tempo di esecuzione troppo lungo.
I due ottimizzatori, disc_opt e gen_opt, sono stati inizializzati come ottimizzatori Adam(). Le funzioni per calcolare le loss che hai definito prima, gen_loss() e disc_loss(), sono a tua disposizione. È stato preparato anche un dataloader.
Ricorda che:
- Gli argomenti di
disc_loss()sono:gen,disc,real,cur_batch_size,z_dim. - Gli argomenti di
gen_loss()sono:gen,disc,cur_batch_size,z_dim.
Questo esercizio fa parte del corso
Deep Learning per Immagini con PyTorch
Istruzioni dell'esercizio
- Calcola la loss del discriminatore usando
disc_loss()passandogli, in quest'ordine, il generatore, il discriminatore, il campione di immagini reali, la dimensione del batch corrente e la dimensione del rumore pari a16, e assegna il risultato ad_loss. - Calcola i gradienti usando
d_loss. - Calcola la loss del generatore usando
gen_loss()passandogli, in quest'ordine, il generatore, il discriminatore, la dimensione del batch corrente e la dimensione del rumore pari a16, e assegna il risultato ag_loss. - Calcola i gradienti usando
g_loss.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break