LoslegenKostenlos loslegen

Trainingsschleife

Endlich zahlt sich die ganze harte Arbeit aus, die du in die Definition der Modellarchitekturen und Verlustfunktionen gesteckt hast: Es ist Zeit für das Training! Deine Aufgabe ist es, den GAN-Trainingszyklus umzusetzen und durchzuführen. Hinweis: Nach dem ersten Datenbatch kommt eine „ break “-Anweisung, damit das Ganze nicht ewig läuft.

Die beiden Optimierer „ disc_opt “ und „ gen_opt “ wurden als Optimierer „ Adam() “ eingerichtet. Die Funktionen zum Berechnen der zuvor definierten Verluste, „ gen_loss() “ und „ disc_loss() “, stehen dir zur Verfügung. Ein „ dataloader ” (Handbuch für den ersten Einsatz) steht dir ebenfalls zur Verfügung.

Denk dran:

  • disc_loss()Die Argumente lauten: gen , disc, real, cur_batch_size, z_dim.
  • gen_loss()Die Argumente lauten: gen , disc, cur_batch_size, z_dim.

Diese Übung ist Teil des Kurses

Deep Learning für Bilder mit PyTorch

Kurs anzeigen

Anleitung zur Übung

  • Berechne den Diskriminatorverlust mit „ disc_loss() “, indem du den Generator, den Diskriminator, die Stichprobe der echten Bilder, die aktuelle Batchgröße und die Rauschgröße von „ 16 “ in dieser Reihenfolge übergibst, und speichere das Ergebnis in „ d_loss “.
  • Berechne Gradienten mit „ d_loss “.
  • Berechne den Generatorverlust mit „ gen_loss() “, indem du den Generator, den Diskriminator, die aktuelle Batchgröße und die Rauschgröße von „ 16 “ in dieser Reihenfolge übergibst, und speichere das Ergebnis in „ g_loss “.
  • Berechne Gradienten mit „ g_loss “.

Interaktive Übung

Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Code bearbeiten und ausführen