Loss del discriminator
È il momento di definire la loss per il discriminator. Ricorda che il compito del discriminator è classificare le immagini come reali o fake. Di conseguenza, il generatore subisce una loss se il discriminator classifica come reali le uscite del generatore (etichetta 1) o se classifica come fake le immagini reali (etichetta 0).
Definisci la funzione disc_loss() che calcola la loss del discriminator. Prende cinque argomenti:
gen, il modello generatoredisc, il modello discriminatorereal, un campione di immagini reali dai dati di trainingnum_images, il numero di immagini nel batchz_dim, la dimensione del rumore casuale in input
Questo esercizio fa parte del corso
Deep Learning per Immagini con PyTorch
Istruzioni dell'esercizio
- Usa il discriminator per classificare le immagini
fakee assegna le predizioni adisc_pred_fake. - Calcola la componente di loss per le fake chiamando
criterionsulle predizioni del discriminator per le immagini fake e su un tensore di zeri della stessa forma. - Usa il discriminator per classificare le immagini
reale assegna le predizioni adisc_pred_real. - Calcola la componente di loss per le real chiamando
criterionsulle predizioni del discriminator per le immagini real e su un tensore di uni della stessa forma.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss