Aan de slagGa gratis aan de slag

Discriminatorverlies

Tijd om de loss voor de discriminator te definiëren. Onthoud dat de taak van de discriminator is om afbeeldingen als echt of nep te classificeren. De discriminator krijgt dus verlies als hij de output van de generator als echt classificeert (label 1) of de echte afbeeldingen als nep (label 0).

Definieer de functie disc_loss() die het discriminatorverlies berekent. Deze neemt vijf argumenten:

  • gen, het generator‑model
  • disc, het discriminator‑model
  • real, een steekproef van echte afbeeldingen uit de trainingsdata
  • num_images, het aantal afbeeldingen in de batch
  • z_dim, de grootte van de invoer‑ruisvector

Deze oefening maakt deel uit van de cursus

Deep Learning voor afbeeldingen met PyTorch

Cursus bekijken

Oefeninstructies

  • Gebruik de discriminator om fake afbeeldingen te classificeren en sla de voorspellingen op in disc_pred_fake.
  • Bereken de fake‑losscomponent door criterion aan te roepen op de voorspellingen van de discriminator voor nepafbeeldingen en een tensor met nullen van dezelfde shape.
  • Gebruik de discriminator om real afbeeldingen te classificeren en sla de voorspellingen op in disc_pred_real.
  • Bereken de real‑losscomponent door criterion aan te roepen op de voorspellingen van de discriminator voor echte afbeeldingen en een tensor met enen van dezelfde shape.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
Code bewerken en uitvoeren