Training GANs
1. Training GANs
It's finally training time! But before we run the training loop, we need to define loss functions for the models.
2. Generator objective
Let's think about the generator's loss function first. Recall that the generator's objective is to create fake images that fool the discriminator into classifying them as real. The key idea is to use the discriminator to inform us about the generator's quality. We will use the generator to produce some fake images and give them to the discriminator to classify. If it misclassifies them as real (label one), the generator is doing a good job and its loss will be small. If it correctly recognizes them as fake (label zero), the generator loss will be large.
3. Generator loss
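The idea above can be sketched as a gen_loss function, which this section then walks through step by step. This is a minimal sketch rather than the course's exact code: the function signature, the choice of nn.BCELoss (which assumes the discriminator outputs probabilities between zero and one), and the tiny stand-in gen and disc models are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


def gen_loss(gen, disc, num_images, z_dim):
    # Random noise of shape (num_images, z_dim) as generator input
    noise = torch.randn(num_images, z_dim)
    # Produce fake images and let the discriminator classify them
    fake = gen(noise)
    disc_pred = disc(fake)
    # Assumption: disc outputs probabilities, so plain BCELoss applies
    criterion = nn.BCELoss()
    # The generator wants its fakes to be labelled as real (ones)
    return criterion(disc_pred, torch.ones_like(disc_pred))


# Hypothetical stand-in models, only here to make the sketch runnable
z_dim, img_dim = 16, 64
gen = nn.Sequential(nn.Linear(z_dim, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 1), nn.Sigmoid())

loss = gen_loss(gen, disc, num_images=8, z_dim=z_dim)
print(loss.item())
```

The loss shrinks toward zero as the discriminator's predictions for the fakes approach one, and grows as they approach zero.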
Let's see it in code. We define a function called gen_loss to compute the generator loss. First, we define random noise as input for the generator. The noise tensor has shape num_images, which corresponds to the batch size, by z_dim, the noise size. Then, we pass the noise to the generator to produce fake images, which we then pass to the discriminator to classify. We define a binary cross-entropy criterion to measure the generator's performance. From the generator's perspective, it is desirable for the discriminator to classify the fakes as real images, which have the label one. Therefore, the generator loss is the binary cross-entropy between the discriminator's predictions for the fakes and a tensor of ones of the same shape, which we create with torch.ones_like.
4. Discriminator objective
Let's turn to the discriminator now. Recall that its objective is to correctly classify fake and real images. To evaluate its loss, we will pass it some generator outputs to see if it correctly recognizes them as fake, or label zero. We will also pass it some real images from the training data, expecting them to be classified as real, or label one. Let's take a look at the code.
5. Discriminator loss
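The objective above can be sketched as a disc_loss function, which this section then describes line by line. Again this is a minimal sketch under assumptions: the signature, nn.BCELoss with a probability-outputting discriminator, the stand-in models, and the random tensor used in place of real training images are not taken from the course code.

```python
import torch
import torch.nn as nn


def disc_loss(gen, disc, real, num_images, z_dim):
    # Assumption: disc outputs probabilities, so plain BCELoss applies
    criterion = nn.BCELoss()
    # Generate fakes from random noise and classify them
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    disc_pred_fake = disc(fake)
    # Fake component: predictions for fakes should be zero
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    # Real component: predictions for real images should be one
    disc_pred_real = disc(real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    # Total discriminator loss is the average of both components
    return (real_loss + fake_loss) / 2


# Hypothetical stand-in models and data, only to make the sketch runnable
z_dim, img_dim = 16, 64
gen = nn.Sequential(nn.Linear(z_dim, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 1), nn.Sigmoid())
real = torch.randn(8, img_dim)  # random tensor standing in for real images

loss = disc_loss(gen, disc, real, num_images=8, z_dim=z_dim)
print(loss.item())
```

Many implementations also call fake.detach() before classifying the fakes, so that the discriminator's backward pass does not reach the generator. The walkthrough does not mention this step, so the sketch leaves it out.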
We define the disc_loss function to compute the discriminator loss. We will use the binary cross-entropy criterion again. First, we produce random noise as before, feed it to the generator to obtain fake images, and pass those images to the discriminator. This way we get disc_pred_fake, the discriminator's predictions for the fake images. Next, we pass those predictions alongside a tensor of zeros to the criterion. This fake loss component is large when the discriminator misclassifies fakes as real. Then, we use the discriminator to classify a sample of real images from the training data. We compute the real loss component by passing these predictions together with a tensor of ones to the criterion. This loss component is large when the discriminator misclassifies real images as fake. Finally, to get the total discriminator loss, we average the real and fake loss components.
6. GAN training loop
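A minimal sketch of the loop this section walks through is shown here. It repeats compact versions of the two loss functions so that it runs on its own. The stand-in models, the learning rate, the number of epochs, and the random tensor used as a dataset are assumptions for illustration, as is the use of nn.BCELoss.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

criterion = nn.BCELoss()  # assumption: disc outputs probabilities


def gen_loss(gen, disc, num_images, z_dim):
    pred = disc(gen(torch.randn(num_images, z_dim)))
    return criterion(pred, torch.ones_like(pred))


def disc_loss(gen, disc, real, num_images, z_dim):
    pred_fake = disc(gen(torch.randn(num_images, z_dim)))
    pred_real = disc(real)
    fake_loss = criterion(pred_fake, torch.zeros_like(pred_fake))
    real_loss = criterion(pred_real, torch.ones_like(pred_real))
    return (real_loss + fake_loss) / 2


# Hypothetical stand-in models, optimizers and data
z_dim, img_dim = 16, 64
gen = nn.Sequential(nn.Linear(z_dim, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 1), nn.Sigmoid())
gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-3)
disc_opt = torch.optim.Adam(disc.parameters(), lr=1e-3)
dataloader = DataLoader(torch.randn(32, img_dim), batch_size=8)

for epoch in range(2):
    for real in dataloader:
        cur_batch_size = len(real)

        # Discriminator: reset gradients, compute loss, backprop, step
        disc_opt.zero_grad()
        d_loss = disc_loss(gen, disc, real, cur_batch_size, z_dim)
        d_loss.backward()
        disc_opt.step()

        # Generator: same four steps with its own loss and optimizer
        gen_opt.zero_grad()
        g_loss = gen_loss(gen, disc, cur_batch_size, z_dim)
        g_loss.backward()
        gen_opt.step()
```

Note that backward is called on the loss tensor, while zero_grad and step are called on the optimizer.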
Let's define the GAN training loop! We iterate over epochs and over batches of real data from a pre-defined dataloader. For each batch, we compute the current batch size. We start with the discriminator. We reset its gradients by calling zero_grad on the discriminator optimizer, which has been pre-defined, for example as an Adam optimizer. Next, we compute the discriminator loss using the function from before and call the backward method on the loss to compute the gradients. Then, we perform the optimization step by calling the step method of the optimizer. We then repeat the process for the generator: reset the gradients, compute the loss using our custom function, compute the gradients, and perform the optimization step.
7. Let's practice!
Now it's your turn to train a GAN!