Trainingsschleife
Endlich zahlt sich all die Arbeit aus, die du in die Modellarchitekturen und Loss-Funktionen gesteckt hast: Es ist Trainingszeit! Deine Aufgabe ist es, die GAN-Trainingsschleife zu implementieren und auszuführen. Hinweis: Eine break-Anweisung ist nach dem ersten Batch eingefügt, um eine lange Laufzeit zu vermeiden.
Die beiden Optimierer, disc_opt und gen_opt, wurden als Adam()-Optimierer initialisiert. Die Funktionen zur Berechnung der Verluste, die du zuvor definiert hast, gen_loss() und disc_loss(), stehen dir zur Verfügung. Ein dataloader ist ebenfalls vorbereitet.
Zur Erinnerung:
- Die Argumente von
disc_loss()sind:gen,disc,real,cur_batch_size,z_dim. - Die Argumente von
gen_loss()sind:gen,disc,cur_batch_size,z_dim.
Diese Übung ist Teil des Kurses
Deep Learning für Bilder mit PyTorch
Anleitung zur Übung
- Berechne den Discriminator-Loss mit
disc_loss(), indem du in dieser Reihenfolge den Generator, den Discriminator, die Stichprobe echter Bilder, die aktuelle Batchgröße und die Rauschgröße von16übergibst, und weise das Ergebnisd_losszu. - Berechne Gradienten mithilfe von
d_loss. - Berechne den Generator-Loss mit
gen_loss(), indem du in dieser Reihenfolge den Generator, den Discriminator, die aktuelle Batchgröße und die Rauschgröße von16übergibst, und weise das Ergebnisg_losszu. - Berechne Gradienten mithilfe von
g_loss.
Interaktive Übung
Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break