Trainingslus
Eindelijk werpt al het harde werk aan de modelarchitecturen en verliesfuncties zijn vruchten af: tijd om te trainen! Jij gaat de GAN-trainingslus implementeren en uitvoeren. Let op: er staat een break-statement na de eerste batch om een lange runtijd te voorkomen.
De twee optimizers, disc_opt en gen_opt, zijn geïnitialiseerd als Adam()-optimizers. De functies om de losses te berekenen die je eerder hebt gedefinieerd, gen_loss() en disc_loss(), zijn beschikbaar. Er is ook een dataloader voor je klaargezet.
Onthoud dat:
- De argumenten van
disc_loss()zijn:gen,disc,real,cur_batch_size,z_dim. - De argumenten van
gen_loss()zijn:gen,disc,cur_batch_size,z_dim.
Deze oefening maakt deel uit van de cursus
Deep Learning voor afbeeldingen met PyTorch
Oefeninstructies
- Bereken de discriminator-loss met
disc_loss()door, in deze volgorde, de generator, de discriminator, de sample met echte afbeeldingen, de huidige batchgrootte en de ruisgrootte van16door te geven, en wijs het resultaat toe aand_loss. - Bereken gradiënten met
d_loss. - Bereken de generator-loss met
gen_loss()door, in deze volgorde, de generator, de discriminator, de huidige batchgrootte en de ruisgrootte van16door te geven, en wijs het resultaat toe aang_loss. - Bereken gradiënten met
g_loss.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break