IniziaInizia gratis

Loss del discriminator

È il momento di definire la loss per il discriminator. Ricorda che il compito del discriminator è classificare le immagini come reali o fake. Di conseguenza, il generatore subisce una loss se il discriminator classifica come reali le uscite del generatore (etichetta 1) o se classifica come fake le immagini reali (etichetta 0).

Definisci la funzione disc_loss() che calcola la loss del discriminator. Prende cinque argomenti:

  • gen, il modello generatore
  • disc, il modello discriminatore
  • real, un campione di immagini reali dai dati di training
  • num_images, il numero di immagini nel batch
  • z_dim, la dimensione del rumore casuale in input

Questo esercizio fa parte del corso

Deep Learning per Immagini con PyTorch

Visualizza il corso

Istruzioni dell'esercizio

  • Usa il discriminator per classificare le immagini fake e assegna le predizioni a disc_pred_fake.
  • Calcola la componente di loss per le fake chiamando criterion sulle predizioni del discriminator per le immagini fake e su un tensore di zeri della stessa forma.
  • Usa il discriminator per classificare le immagini real e assegna le predizioni a disc_pred_real.
  • Calcola la componente di loss per le real chiamando criterion sulle predizioni del discriminator per le immagini real e su un tensore di uni della stessa forma.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
Modifica ed esegui il codice