Discriminator-Loss
Jetzt definieren wir den Loss für den Discriminator. Denk daran: Die Aufgabe des Discriminators ist es, Bilder als echt oder fake zu klassifizieren. Der Generator verursacht also einen Loss, wenn er die Ausgaben des Generators als echt (Label 1) oder die echten Bilder als fake (Label 0) klassifiziert.
Definiere die Funktion disc_loss(), die den Discriminator-Loss berechnet. Sie nimmt fünf Argumente entgegen:
gen, das Generatormodelldisc, das Discriminatormodellreal, eine Stichprobe echter Bilder aus den Trainingsdatennum_images, die Anzahl der Bilder im Batchz_dim, die Größe des Eingabe-Rauschvektors
Diese Übung ist Teil des Kurses
Deep Learning für Bilder mit PyTorch
Anleitung zur Übung
- Verwende den Discriminator, um
fake-Bilder zu klassifizieren, und speichere die Vorhersagen indisc_pred_fake. - Berechne die Fake-Loss-Komponente, indem du
criterionauf die Vorhersagen des Discriminators für Fake-Bilder und auf einen Tensor aus Nullen gleicher Form anwendest. - Verwende den Discriminator, um
real-Bilder zu klassifizieren, und speichere die Vorhersagen indisc_pred_real. - Berechne die Real-Loss-Komponente, indem du
criterionauf die Vorhersagen des Discriminators für echte Bilder und auf einen Tensor aus Einsen gleicher Form anwendest.
Interaktive Übung
Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss