Discriminatorverlies
Tijd om de loss voor de discriminator te definiëren. Onthoud dat de taak van de discriminator is om afbeeldingen als echt of nep te classificeren. De discriminator krijgt dus verlies als hij de output van de generator als echt classificeert (label 1) of de echte afbeeldingen als nep (label 0).
Definieer de functie disc_loss() die het discriminatorverlies berekent. Deze neemt vijf argumenten:
gen, het generator‑modeldisc, het discriminator‑modelreal, een steekproef van echte afbeeldingen uit de trainingsdatanum_images, het aantal afbeeldingen in de batchz_dim, de grootte van de invoer‑ruisvector
Deze oefening maakt deel uit van de cursus
Deep Learning voor afbeeldingen met PyTorch
Oefeninstructies
- Gebruik de discriminator om
fakeafbeeldingen te classificeren en sla de voorspellingen op indisc_pred_fake. - Bereken de fake‑losscomponent door
criterionaan te roepen op de voorspellingen van de discriminator voor nepafbeeldingen en een tensor met nullen van dezelfde shape. - Gebruik de discriminator om
realafbeeldingen te classificeren en sla de voorspellingen op indisc_pred_real. - Bereken de real‑losscomponent door
criterionaan te roepen op de voorspellingen van de discriminator voor echte afbeeldingen en een tensor met enen van dezelfde shape.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss