Training a GAN model

Your team at PyBooks has made good progress in building a text generator using a Generative Adversarial Network (GAN). You have defined the generator and discriminator networks; now it's time to train them. The final step is to generate some fake data and compare it with the real data to see how well your GAN has learned. The real data here are numeric tensors, and the generator learns to produce tensors that resemble them. The PyBooks team can then use this synthetic data for text analysis, since the generated features should follow the same relationships as the real text features.
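The comparison step described above can be sketched as a small helper. This is a minimal sketch, not the course's method: it assumes the generator maps noise of shape (n, seq_length) to outputs of the same shape, and it uses per-position means and position-to-position correlations as the comparison metrics, which is our own choice. The names compare_with_real and summary, the stand-in data, and the untrained stand-in generator are hypothetical. torch.corrcoef requires PyTorch 1.10 or later.

```python
import torch
import torch.nn as nn


def compare_with_real(generator: nn.Module, data: torch.Tensor, seq_length: int) -> dict:
    """Generate as many fake sequences as there are real ones and summarise both (hypothetical helper)."""
    num_sequences = data.shape[0]
    with torch.no_grad():  # evaluation only, so no gradients are tracked
        noise = torch.rand((num_sequences, seq_length))
        fake = generator(noise)
    return {
        "real_mean": data.mean(dim=0),
        "fake_mean": fake.mean(dim=0),
        # torch.corrcoef treats each row as a variable, so transpose
        # to get a (seq_length x seq_length) matrix of position correlations
        "real_corr": torch.corrcoef(data.T),
        "fake_corr": torch.corrcoef(fake.T),
        "fake": fake,
    }


# Usage with hypothetical stand-ins: random "real" data and an untrained generator
torch.manual_seed(0)
seq_length, num_sequences = 5, 100
data = torch.rand((num_sequences, seq_length))
generator = nn.Sequential(nn.Linear(seq_length, seq_length), nn.Sigmoid())

summary = compare_with_real(generator, data, seq_length)
print("Real data (first 3 rows):\n", data[:3])
print("Generated data (first 3 rows):\n", summary["fake"][:3])
print("Per-position mean, real:     ", summary["real_mean"])
print("Per-position mean, generated:", summary["fake_mean"])
print("Mean absolute gap in correlations:",
      (summary["real_corr"] - summary["fake_corr"]).abs().mean().item())
```

With a trained generator, the per-position means and the correlation gap give a rough, quantitative sense of whether the synthetic sequences mirror the real ones; a small gap alone does not prove the distributions match.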

The generator and discriminator have been initialized and saved to generator and discriminator, respectively.

The following variables have been initialized in the exercise:

  • seq_length = 5: Length of each synthetic data sequence
  • num_sequences = 100: Total number of sequences generated
  • num_epochs = 50: Number of complete passes through the dataset
  • print_every = 10: Output display frequency, showing results every 10 epochs
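The exercise does not show how data, generator and discriminator are built. The following minimal sketch assumes data is a (num_sequences, seq_length) tensor of uniform random values and that each network is a single linear layer followed by a sigmoid. These architectures are hypothetical stand-ins, not necessarily the course's; the point is the shapes: the generator turns a (1, seq_length) noise vector into a (1, seq_length) sequence, and the discriminator turns a sequence into a single probability that it is real.

```python
import torch
import torch.nn as nn

seq_length = 5
num_sequences = 100

# Hypothetical stand-in for the real data: 100 sequences of 5 values in [0, 1)
data = torch.rand((num_sequences, seq_length))

# Hypothetical generator: maps a noise vector to a sequence of the same length
generator = nn.Sequential(
    nn.Linear(seq_length, seq_length),
    nn.Sigmoid(),
)

# Hypothetical discriminator: maps a sequence to the probability that it is real
discriminator = nn.Sequential(
    nn.Linear(seq_length, 1),
    nn.Sigmoid(),
)

noise = torch.rand((1, seq_length))
fake = generator(noise)
score = discriminator(fake)
print(data.shape, fake.shape, score.shape)
```

The final sigmoid on the discriminator matters: it keeps the output in (0, 1), which is what nn.BCELoss expects as input.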

This exercise is part of the course

Deep Learning for Text with PyTorch


The sample code below defines the loss function and the optimizers, then runs the GAN training loop.

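import torch
import torch.nn as nn

# Define the loss function and optimizers
# BCELoss expects probabilities, so the discriminator is assumed to end in a Sigmoid
criterion = nn.BCELoss()
optimizer_gen = torch.optim.Adam(generator.parameters(), lr=0.001)
optimizer_disc = torch.optim.Adam(discriminator.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for real_data in data:
        # Add a batch dimension: shape (seq_length,) -> (1, seq_length)
        real_data = real_data.unsqueeze(0)
        noise = torch.rand((1, seq_length))
        fake_data = generator(noise)

        # Discriminator step: detach fake_data so no gradients flow into the generator
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.detach())
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()

        # Generator step: reward the generator when the discriminator labels fakes as real
        disc_fake = discriminator(fake_data)
        loss_gen = criterion(disc_fake, torch.ones_like(disc_fake))
        optimizer_gen.zero_grad()
        loss_gen.backward()
        optimizer_gen.step()

    if (epoch + 1) % print_every == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}: generator loss {loss_gen.item():.4f}, discriminator loss {loss_disc.item():.4f}")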
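The discriminator loss in the training loop is binary cross-entropy with fixed targets: ones for real data and zeros for generated data. The minimal sketch below uses made-up discriminator scores of 0.9 and 0.2 (purely illustrative values) to check that nn.BCELoss with these targets reduces to the familiar GAN terms, -log D(x) - log(1 - D(G(z))) for the discriminator, and -log D(G(z)) for the commonly used non-saturating generator loss, which labels generated data as real.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Illustrative discriminator outputs (probability that the input is real)
d_real = torch.tensor([[0.9]])  # score for a real sequence, D(x)
d_fake = torch.tensor([[0.2]])  # score for a generated sequence, D(G(z))

# Discriminator loss: real -> target 1, fake -> target 0
loss_disc = criterion(d_real, torch.ones_like(d_real)) + criterion(d_fake, torch.zeros_like(d_fake))
# The same quantity written out explicitly
manual_disc = -torch.log(d_real) - torch.log(1 - d_fake)

# Non-saturating generator loss: fake -> target 1
loss_gen = criterion(d_fake, torch.ones_like(d_fake))
manual_gen = -torch.log(d_fake)

print(f"discriminator: {loss_disc.item():.4f} vs {manual_disc.item():.4f}")
print(f"generator:     {loss_gen.item():.4f} vs {manual_gen.item():.4f}")
```

Because the fake score of 0.2 is low, the discriminator loss is small while the generator loss is large; training pushes the two in opposite directions.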