Aan de slagGa gratis aan de slag

Generatorverlies

Voordat je je GAN kunt trainen, moet je verliesfuncties definiëren voor zowel de generator als de discriminator. Je begint met de eerste.

Onthoud dat het doel van de generator is om zulke nepafbeeldingen te maken dat de discriminator ze als echt classificeert. Daarom krijgt de generator verlies als de door hem gegenereerde afbeeldingen door de discriminator als nep (label 0) worden geclassificeerd.

Definieer de functie gen_loss() die het verlies van de generator berekent. Deze neemt vier argumenten:

  • gen, het generatormodel
  • disc, het discriminatormodel
  • num_images, het aantal afbeeldingen in de batch
  • z_dim, de grootte van de invoer-ruis

Deze oefening maakt deel uit van de cursus

Deep Learning voor afbeeldingen met PyTorch

Cursus bekijken

Oefeninstructies

  • Genereer willekeurige ruis met vorm num_images bij z_dim en wijs dit toe aan noise.
  • Gebruik de generator om van noise een nepafbeelding te genereren en wijs dit toe aan fake.
  • Haal de voorspelling van de discriminator op voor de gegenereerde nepafbeelding.
  • Bereken het verlies van de generator door criterion aan te roepen op de voorspellingen van de discriminator en een tensor van enen met dezelfde vorm.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Define random noise
    noise = ____(num_images, z_dim)
    # Generate fake image
    fake = ____
    # Get discriminator's prediction on the fake image
    disc_pred = ____
    # Compute generator loss
    criterion = nn.BCEWithLogitsLoss()
    gen_loss = ____(____, ____)
    return gen_loss
Code bewerken en uitvoeren