MulaiMulai sekarang secara gratis

Training loop

Akhirnya, semua kerja keras Anda dalam mendefinisikan arsitektur model dan fungsi loss membuahkan hasil: saatnya melatih model! Tugas Anda adalah mengimplementasikan dan mengeksekusi training loop GAN. Catatan: sebuah pernyataan break ditempatkan setelah batch data pertama untuk menghindari waktu eksekusi yang lama.

Dua optimizer, disc_opt dan gen_opt, telah diinisialisasi sebagai optimizer Adam(). Fungsi untuk menghitung loss yang Anda definisikan sebelumnya, yaitu gen_loss() dan disc_loss(), tersedia untuk Anda. Sebuah dataloader juga telah disiapkan.

Ingat bahwa:

  • Argumen disc_loss() adalah: gen, disc, real, cur_batch_size, z_dim.
  • Argumen gen_loss() adalah: gen, disc, cur_batch_size, z_dim.

Latihan ini adalah bagian dari kursus

Deep Learning untuk Gambar dengan PyTorch

Lihat Kursus

Petunjuk latihan

  • Hitung loss discriminator menggunakan disc_loss() dengan meneruskan generator, discriminator, sampel citra nyata, ukuran batch saat ini, dan ukuran noise 16, dalam urutan tersebut, lalu simpan hasilnya ke d_loss.
  • Hitung gradien menggunakan d_loss.
  • Hitung loss generator menggunakan gen_loss() dengan meneruskan generator, discriminator, ukuran batch saat ini, dan ukuran noise 16, dalam urutan tersebut, lalu simpan hasilnya ke g_loss.
  • Hitung gradien menggunakan g_loss.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Edit dan Jalankan Kode