
Training loop

Finally, all the hard work you put into defining the model architectures and loss functions comes to fruition: it's training time! Your job is to implement and execute the GAN training loop. Note: a break statement is placed after the first batch of data to avoid a long runtime.

The two optimizers, disc_opt and gen_opt, have been initialized as Adam() optimizers. The loss functions you defined earlier, gen_loss() and disc_loss(), are available to you, and a dataloader has also been prepared.
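For context, that pre-built setup might look roughly like the sketch below. This is an assumption, not the course's exact code: the learning rate and batch size are illustrative, and dataset stands in for the image dataset prepared earlier.

import torch
from torch.utils.data import DataLoader

# gen and disc are the generator and discriminator defined in earlier exercises
disc_opt = torch.optim.Adam(disc.parameters(), lr=0.001)  # lr is an illustrative assumption
gen_opt = torch.optim.Adam(gen.parameters(), lr=0.001)

# dataset stands in for the prepared image dataset; batch size is an assumption
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)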

Recall that:

  • disc_loss()'s arguments are: gen, disc, real, cur_batch_size, z_dim.
  • gen_loss()'s arguments are: gen, disc, cur_batch_size, z_dim.
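For reference, BCE-based versions of these two functions often look like the sketch below. This is not the course's exact code; it assumes disc returns one raw logit per image and gen maps a (cur_batch_size, z_dim) noise tensor to a batch of images.

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

def disc_loss(gen, disc, real, cur_batch_size, z_dim):
    # The discriminator should score fakes as 0 and real images as 1
    noise = torch.randn(cur_batch_size, z_dim)
    fake = gen(noise).detach()  # detach so this step does not update the generator
    fake_loss = criterion(disc(fake), torch.zeros(cur_batch_size, 1))
    real_loss = criterion(disc(real), torch.ones(cur_batch_size, 1))
    return (fake_loss + real_loss) / 2

def gen_loss(gen, disc, cur_batch_size, z_dim):
    # The generator is rewarded when the discriminator labels its fakes as real
    noise = torch.randn(cur_batch_size, z_dim)
    fake = gen(noise)
    return criterion(disc(fake), torch.ones(cur_batch_size, 1))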

This exercise is part of the course Deep Learning for Images with PyTorch.

Exercise instructions

  • Calculate the discriminator loss with disc_loss(), passing it the generator, the discriminator, the batch of real images, the current batch size, and the noise size of 16, in this order, and assign the result to d_loss.
  • Compute the gradients from d_loss.
  • Calculate the generator loss with gen_loss(), passing it the generator, the discriminator, the current batch size, and the noise size of 16, in this order, and assign the result to g_loss.
  • Compute the gradients from g_loss.

Hands-on interactive exercise

Finish this exercise by completing the sample code below.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
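For reference, one possible completion of the loop, assuming the objects described above and the noise size of 16 that the instructions specify:

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)

        # Discriminator step
        disc_opt.zero_grad()
        d_loss = disc_loss(gen, disc, real, cur_batch_size, 16)
        d_loss.backward()
        disc_opt.step()

        # Generator step
        gen_opt.zero_grad()
        g_loss = gen_loss(gen, disc, cur_batch_size, 16)
        g_loss.backward()
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break  # stop after the first batch to keep the runtime short

Calling backward() on each loss fills the .grad attributes of the relevant parameters, which the matching optimizer's step() then consumes; zero_grad() clears the gradients left over from the previous iteration so the two updates stay independent.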