Generator loss
Before you can train your GAN, you need to define loss functions for both the generator and the discriminator. You will start with the former.
Recall that the generator's job is to produce fake images that fool the discriminator into classifying them as real. Therefore, the generator incurs a loss whenever the discriminator classifies its generated images as fake (label 0).
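To make this concrete, here is a small sketch of how the loss behaves for a single prediction. It assumes the discriminator outputs a raw logit and that the loss is binary cross-entropy against the label 1 (real), which is what nn.BCEWithLogitsLoss computes for a target of ones. The helper name bce_with_logits_for_real_label is hypothetical and only used for illustration.

```python
import math

def bce_with_logits_for_real_label(logit: float) -> float:
    """Binary cross-entropy for target 1 given a raw logit: -log(sigmoid(logit))."""
    # -log(sigmoid(x)) equals log(1 + exp(-x)); this form avoids overflow
    # for large negative logits.
    return max(-logit, 0.0) + math.log1p(math.exp(-abs(logit)))

# Discriminator is confident the image is fake (negative logit):
# the generator is penalised heavily.
loss_when_caught = bce_with_logits_for_real_label(-3.0)

# Discriminator is confident the image is real (positive logit):
# the generator is penalised very little.
loss_when_fooled = bce_with_logits_for_real_label(3.0)

print(f"caught: {loss_when_caught:.4f}, fooled: {loss_when_fooled:.4f}")
```

The more confidently the discriminator labels a generated image as fake, the larger the generator's loss, so minimising this loss pushes the generator towards images the discriminator accepts as real.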
Define the gen_loss() function that calculates the generator loss. It takes four arguments:
- gen, the generator model
- disc, the discriminator model
- num_images, the number of images in the batch
- z_dim, the size of the input random noise
This exercise is part of the course Deep Learning for Images with PyTorch.
Exercise instructions
- Generate random noise of shape num_images by z_dim and assign it to noise.
- Use the generator to generate a fake image from noise and assign it to fake.
- Get the discriminator's prediction for the generated fake image.
- Compute the generator's loss by calling criterion on the discriminator's predictions and a tensor of ones of the same shape.
Hands-on interactive exercise
Complete the sample code to finish this exercise.
import torch
import torch.nn as nn

def gen_loss(gen, disc, num_images, z_dim):
    # Define random noise
    noise = ____(num_images, z_dim)
    # Generate fake image
    fake = ____
    # Get discriminator's prediction on the fake image
    disc_pred = ____
    # Compute generator loss
    criterion = nn.BCEWithLogitsLoss()
    gen_loss = ____(____, ____)
    return gen_loss
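One possible way to complete the exercise is sketched below, using the four arguments listed in the text. It is not necessarily the course's reference solution. Two choices here are assumptions: the noise is drawn from a standard normal distribution with torch.randn, and the tensor of ones is built with torch.ones_like. The tiny gen and disc models, and the sizes 16 and 784, are hypothetical stand-ins that exist only to make the sketch runnable.

```python
import torch
import torch.nn as nn

def gen_loss(gen, disc, num_images, z_dim):
    # Sample one z_dim-sized noise vector per image
    # (standard normal noise is an assumed choice).
    noise = torch.randn(num_images, z_dim)
    # Generate fake images from the noise.
    fake = gen(noise)
    # Get the discriminator's raw predictions (logits) for the fakes.
    disc_pred = disc(fake)
    # Compare predictions against all-ones targets: the generator is
    # rewarded when its fakes are classified as real (label 1).
    criterion = nn.BCEWithLogitsLoss()
    return criterion(disc_pred, torch.ones_like(disc_pred))

# Hypothetical stand-in models, only to exercise the function.
z_dim, img_dim = 16, 784
gen = nn.Sequential(nn.Linear(z_dim, img_dim), nn.Sigmoid())
disc = nn.Linear(img_dim, 1)

loss = gen_loss(gen, disc, num_images=8, z_dim=z_dim)
loss.backward()  # gradients flow back through disc into gen
print(loss.item())
```

With its default settings, nn.BCEWithLogitsLoss averages over the batch, so the function returns a single scalar tensor that can be passed straight to backward() during training.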