Training loop
Finally, all the hard work you put into defining the model architectures and loss functions comes to fruition: it's training time! Your job is to implement and execute the GAN training loop. Note: a break statement is placed after the first batch of data to avoid a long runtime.
The two optimizers, disc_opt and gen_opt, have been initialized as Adam() optimizers. The functions you defined earlier to compute the losses, gen_loss() and disc_loss(), are available to you. A dataloader is also prepared for you.
Recall that:
- disc_loss() takes the arguments gen, disc, real, cur_batch_size, z_dim.
- gen_loss() takes the arguments gen, disc, cur_batch_size, z_dim.
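Only the signatures of these functions are given here. A minimal sketch of what they might contain, assuming the standard binary cross-entropy GAN objective and a discriminator that returns raw logits (if the discriminator ends in a sigmoid, nn.BCELoss would replace nn.BCEWithLogitsLoss). The bodies are assumptions, not the course's definitions:

```python
import torch
from torch import nn

# Assumption: disc returns raw logits, so the sigmoid is folded into the loss
criterion = nn.BCEWithLogitsLoss()

def gen_loss(gen, disc, cur_batch_size, z_dim):
    # The generator wants the discriminator to label its fakes as real (1)
    noise = torch.randn(cur_batch_size, z_dim)
    fake = gen(noise)
    disc_pred = disc(fake)
    return criterion(disc_pred, torch.ones_like(disc_pred))

def disc_loss(gen, disc, real, cur_batch_size, z_dim):
    noise = torch.randn(cur_batch_size, z_dim)
    # detach() so this loss does not backpropagate into the generator
    fake = gen(noise).detach()
    fake_pred = disc(fake)
    real_pred = disc(real)
    # Fakes should be labelled 0, real images 1; average the two terms
    fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
    real_loss = criterion(real_pred, torch.ones_like(real_pred))
    return (fake_loss + real_loss) / 2
```

With losses written this way, calling backward() on the discriminator loss fills only discriminator gradients, while backward() on the generator loss also writes gradients into the discriminator. That is why each update in the loop starts with its own zero_grad() call.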
This exercise is part of the course Deep Learning for Images with PyTorch.
Exercise instructions
- Calculate the discriminator loss using disc_loss() by passing it the generator, the discriminator, the sample of real images, the current batch size, and the noise size of 16, in this order, and assign the result to d_loss.
- Compute the discriminator gradients by calling backward() on d_loss.
- Calculate the generator loss using gen_loss() by passing it the generator, the discriminator, the current batch size, and the noise size of 16, in this order, and assign the result to g_loss.
- Compute the generator gradients by calling backward() on g_loss.
Sample code for the training loop:
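```python
for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)

        # Discriminator update
        disc_opt.zero_grad()
        # Calculate discriminator loss (noise size 16)
        d_loss = disc_loss(gen, disc, real, cur_batch_size, 16)
        # Compute gradients
        d_loss.backward()
        disc_opt.step()

        # Generator update
        gen_opt.zero_grad()
        # Calculate generator loss (noise size 16)
        g_loss = gen_loss(gen, disc, cur_batch_size, 16)
        # Compute generator gradients
        g_loss.backward()
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        # Stop after the first batch to keep the runtime short
        break
```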