Training loop
Akhirnya, semua kerja keras Anda dalam mendefinisikan arsitektur model dan fungsi loss membuahkan hasil: saatnya melatih model! Tugas Anda adalah mengimplementasikan dan mengeksekusi training loop GAN. Catatan: sebuah pernyataan break ditempatkan setelah batch data pertama untuk menghindari waktu eksekusi yang lama.
Dua optimizer, disc_opt dan gen_opt, telah diinisialisasi sebagai optimizer Adam(). Fungsi untuk menghitung loss yang Anda definisikan sebelumnya, yaitu gen_loss() dan disc_loss(), tersedia untuk Anda. Sebuah dataloader juga telah disiapkan.
Ingat bahwa:
- Argumen
disc_loss()adalah:gen,disc,real,cur_batch_size,z_dim. - Argumen
gen_loss()adalah:gen,disc,cur_batch_size,z_dim.
Latihan ini adalah bagian dari kursus
Deep Learning untuk Gambar dengan PyTorch
Petunjuk latihan
- Hitung loss discriminator menggunakan
disc_loss()dengan meneruskan generator, discriminator, sampel citra nyata, ukuran batch saat ini, dan ukuran noise16, dalam urutan tersebut, lalu simpan hasilnya ked_loss. - Hitung gradien menggunakan
d_loss. - Hitung loss generator menggunakan
gen_loss()dengan meneruskan generator, discriminator, ukuran batch saat ini, dan ukuran noise16, dalam urutan tersebut, lalu simpan hasilnya keg_loss. - Hitung gradien menggunakan
g_loss.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break