Aan de slagGa gratis aan de slag

Trainingslus

Eindelijk werpt al het harde werk aan de modelarchitecturen en verliesfuncties zijn vruchten af: tijd om te trainen! Jij gaat de GAN-trainingslus implementeren en uitvoeren. Let op: er staat een break-statement na de eerste batch om een lange runtijd te voorkomen.

De twee optimizers, disc_opt en gen_opt, zijn geïnitialiseerd als Adam()-optimizers. De functies om de losses te berekenen die je eerder hebt gedefinieerd, gen_loss() en disc_loss(), zijn beschikbaar. Er is ook een dataloader voor je klaargezet.

Onthoud dat:

  • De argumenten van disc_loss() zijn: gen, disc, real, cur_batch_size, z_dim.
  • De argumenten van gen_loss() zijn: gen, disc, cur_batch_size, z_dim.

Deze oefening maakt deel uit van de cursus

Deep Learning voor afbeeldingen met PyTorch

Cursus bekijken

Oefeninstructies

  • Bereken de discriminator-loss met disc_loss() door, in deze volgorde, de generator, de discriminator, de sample met echte afbeeldingen, de huidige batchgrootte en de ruisgrootte van 16 door te geven, en wijs het resultaat toe aan d_loss.
  • Bereken gradiënten met d_loss.
  • Bereken de generator-loss met gen_loss() door, in deze volgorde, de generator, de discriminator, de huidige batchgrootte en de ruisgrootte van 16 door te geven, en wijs het resultaat toe aan g_loss.
  • Bereken gradiënten met g_loss.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Code bewerken en uitvoeren