LoslegenKostenlos loslegen

Trainingsschleife

Endlich zahlt sich all die Arbeit aus, die du in die Modellarchitekturen und Loss-Funktionen gesteckt hast: Es ist Trainingszeit! Deine Aufgabe ist es, die GAN-Trainingsschleife zu implementieren und auszuführen. Hinweis: Eine break-Anweisung ist nach dem ersten Batch eingefügt, um eine lange Laufzeit zu vermeiden.

Die beiden Optimierer, disc_opt und gen_opt, wurden als Adam()-Optimierer initialisiert. Die Funktionen zur Berechnung der Verluste, die du zuvor definiert hast, gen_loss() und disc_loss(), stehen dir zur Verfügung. Ein dataloader ist ebenfalls vorbereitet.

Zur Erinnerung:

  • Die Argumente von disc_loss() sind: gen, disc, real, cur_batch_size, z_dim.
  • Die Argumente von gen_loss() sind: gen, disc, cur_batch_size, z_dim.

Diese Übung ist Teil des Kurses

Deep Learning für Bilder mit PyTorch

Kurs anzeigen

Anleitung zur Übung

  • Berechne den Discriminator-Loss mit disc_loss(), indem du in dieser Reihenfolge den Generator, den Discriminator, die Stichprobe echter Bilder, die aktuelle Batchgröße und die Rauschgröße von 16 übergibst, und weise das Ergebnis d_loss zu.
  • Berechne Gradienten mithilfe von d_loss.
  • Berechne den Generator-Loss mit gen_loss(), indem du in dieser Reihenfolge den Generator, den Discriminator, die aktuelle Batchgröße und die Rauschgröße von 16 übergibst, und weise das Ergebnis g_loss zu.
  • Berechne Gradienten mithilfe von g_loss.

Interaktive Übung

Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Code bearbeiten und ausführen