Training a GAN model
Your team at PyBooks has made good progress in building a text generator using a Generative Adversarial Network (GAN). You have already defined the generator and discriminator networks, and now it's time to train them. The final step is to generate some fake data and compare it with the real data to see how well your GAN has learned. The real data here are tensors, and the generator learns to produce output tensors that resemble them. The PyBooks team can then use this synthetic data for text analysis, since its features should follow the same relationships as the features of the real text data.
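One simple way to check whether synthetic data keeps the "same relationship" as the real data is to compare the average value at each sequence position and the correlations between positions. The helper below, compare_real_fake, is a hypothetical illustration of that idea, not part of the exercise; both inputs are assumed to be tensors of shape (num_rows, seq_length).

```python
import torch


def compare_real_fake(real: torch.Tensor, fake: torch.Tensor) -> dict:
    """Hypothetical helper: summarize how closely fake rows match real rows.

    Both tensors are assumed to have shape (num_rows, seq_length).
    Smaller gaps mean the fake data is statistically closer to the real data.
    """
    # Average absolute difference between the per-position means
    mean_gap = (real.mean(dim=0) - fake.mean(dim=0)).abs().mean().item()
    # torch.corrcoef treats each row as a variable, so transpose to
    # correlate sequence positions (columns) with each other
    corr_real = torch.corrcoef(real.T)
    corr_fake = torch.corrcoef(fake.T)
    corr_gap = (corr_real - corr_fake).abs().mean().item()
    return {"mean_gap": mean_gap, "corr_gap": corr_gap}


# Usage with random stand-in tensors
torch.manual_seed(0)
real = torch.rand((100, 5))
fake = torch.rand((100, 5)) * 0.5  # a deliberately poor "generator" output
print(compare_real_fake(real, fake))
```

Correlation is unaffected by shifting or rescaling a position, so the two numbers capture different things: mean_gap flags values in the wrong range, while corr_gap flags positions that no longer move together the way they do in the real data.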
The generator and discriminator have been initialized and saved to generator and discriminator, respectively.
The following variables have been initialized in the exercise:
- seq_length = 5: Length of each synthetic data sequence
- num_sequences = 100: Total number of sequences generated
- num_epochs = 50: Number of complete passes through the dataset
- print_every = 10: Output display frequency, showing results every 10 epochs
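To make the tensor shapes in the training loop concrete, here is a minimal sketch of what the pre-initialized objects might look like. The Generator and Discriminator classes are hypothetical stand-ins (one linear layer each, followed by a sigmoid), and data is assumed to be a (num_sequences, seq_length) tensor of random values; the exercise's actual definitions may differ.

```python
import torch
import torch.nn as nn

seq_length = 5
num_sequences = 100


class Generator(nn.Module):
    """Hypothetical generator: maps a noise vector to a fake sequence."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(seq_length, seq_length), nn.Sigmoid())

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    """Hypothetical discriminator: outputs the probability that a sequence is real."""

    def __init__(self):
        super().__init__()
        # The final Sigmoid yields probabilities, which nn.BCELoss expects
        self.model = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid())

    def forward(self, x):
        return self.model(x)


generator = Generator()
discriminator = Discriminator()

# Assumed real data: num_sequences rows of seq_length values in [0, 1)
data = torch.rand((num_sequences, seq_length))

real_data = data[0]                  # iterating over data yields rows of shape (seq_length,)
batch = real_data.unsqueeze(0)       # add a batch dimension: (1, seq_length)
noise = torch.rand((1, seq_length))  # generator input with the same shape as one batch
print(batch.shape, generator(noise).shape, discriminator(batch).shape)
# torch.Size([1, 5]) torch.Size([1, 5]) torch.Size([1, 1])
```

Because the generator maps a (1, seq_length) noise tensor to a (1, seq_length) output, real and fake sequences reach the discriminator with identical shapes, which is why the training loop unsqueezes each real row before scoring it.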
This exercise is part of the course Deep Learning for Text with PyTorch.
The sample code below defines the loss function and optimizers, then trains the discriminator on each real sequence and one generated sequence.
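import torch
import torch.nn as nn
import torch.optim as optim

# Define the loss function and optimizers
# (generator, discriminator, data, seq_length and num_epochs are provided by the exercise)
# BCELoss assumes the discriminator outputs probabilities (e.g., a final Sigmoid)
criterion = nn.BCELoss()
# Adam is assumed for both optimizers
optimizer_gen = optim.Adam(generator.parameters(), lr=0.001)
optimizer_disc = optim.Adam(discriminator.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for real_data in data:
        # Add a batch dimension: (seq_length,) -> (1, seq_length)
        real_data = real_data.unsqueeze(0)
        noise = torch.rand((1, seq_length))
        fake_data = generator(noise)
        disc_real = discriminator(real_data)
        # Detach fake_data so the discriminator loss does not backpropagate into the generator
        disc_fake = discriminator(fake_data.detach())
        # Real sequences are labeled 1, fake sequences 0
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()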
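The loop above only updates the discriminator; optimizer_gen is defined but never used. A complete run also needs a generator step, progress output every print_every epochs, and the final generate-and-compare step. The sketch below fills these in under stated assumptions: hypothetical single-layer networks and random data stand in for the exercise's objects, and the generator step (re-scoring fake_data without detaching it and labeling it as real) is the standard GAN recipe, not code taken from the exercise.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
seq_length, num_sequences, num_epochs, print_every = 5, 100, 50, 10

# Hypothetical stand-ins for the exercise's pre-built networks and data
generator = nn.Sequential(nn.Linear(seq_length, seq_length), nn.Sigmoid())
discriminator = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid())
data = torch.rand((num_sequences, seq_length))

criterion = nn.BCELoss()
optimizer_gen = optim.Adam(generator.parameters(), lr=0.001)
optimizer_disc = optim.Adam(discriminator.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for real_data in data:
        real_data = real_data.unsqueeze(0)
        noise = torch.rand((1, seq_length))
        fake_data = generator(noise)

        # Discriminator step: real -> 1, fake -> 0; detach keeps the generator out of this update
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.detach())
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(
            disc_fake, torch.zeros_like(disc_fake)
        )
        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()

        # Generator step: score fake_data again without detaching, and reward
        # the generator when the discriminator labels it as real (target 1)
        disc_fake_pred = discriminator(fake_data)
        loss_gen = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        optimizer_gen.zero_grad()
        loss_gen.backward()
        optimizer_gen.step()

    if (epoch + 1) % print_every == 0:
        print(
            f"Epoch {epoch + 1}/{num_epochs}: "
            f"Generator loss: {loss_gen.item():.4f}, Discriminator loss: {loss_disc.item():.4f}"
        )

# Final step: generate synthetic sequences and compare them with the real data
with torch.no_grad():
    fake = generator(torch.rand((num_sequences, seq_length)))
print("Real data mean per position:", data.mean(dim=0))
print("Fake data mean per position:", fake.mean(dim=0))
```

The generator's backward pass also writes gradients into the discriminator's parameters. That is harmless here because optimizer_disc.zero_grad() clears them before the next discriminator update, and only optimizer_gen steps on this loss.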