Training a GAN model

Your team at PyBooks has made good progress building a text generator with a Generative Adversarial Network (GAN). You have defined the generator and discriminator networks; now it's time to train them. The final step is to generate some fake data and compare it with the real data to see how well your GAN has learned. The inputs are tensors, and the generator's output should come to resemble those input tensors. The team at PyBooks can then use this synthetic data for text analysis, since its features will have the same relationships as the text data.

The generator and discriminator have been initialized and saved to generator and discriminator, respectively.

The following variables have been initialized in the exercise:

  • seq_length = 5: Length of each synthetic data sequence
  • num_sequences = 100: Total number of sequences generated
  • num_epochs = 50: Number of complete passes through the dataset
  • print_every = 10: Output display frequency, showing results every 10 epochs
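To reproduce the exercise environment locally, a minimal sketch might look like the following. The architectures and the dataset here are assumptions made for illustration (single linear layers with sigmoid outputs, and random binary sequences standing in for `data`); the exercise preloads its own `generator`, `discriminator`, and `data`, which may differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible sketch

seq_length = 5       # length of each synthetic data sequence
num_sequences = 100  # total number of sequences

# ASSUMPTION: random binary sequences stand in for the preloaded dataset
data = torch.randint(0, 2, (num_sequences, seq_length)).float()


class Generator(nn.Module):
    """Hypothetical generator: maps noise to a sequence with values in (0, 1)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(seq_length, seq_length), nn.Sigmoid())

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    """Hypothetical discriminator: maps a sequence to a probability of being real."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid())

    def forward(self, x):
        return self.model(x)


generator = Generator()
discriminator = Discriminator()

# Sanity check: one noise vector in, one fake sequence and one score out
noise = torch.rand((1, seq_length))
fake = generator(noise)
score = discriminator(fake)
print(fake.shape, score.shape)
```

With these stand-ins, each element of `data` is a 1-D tensor of length `seq_length`, which is why the training loop adds a batch dimension before passing it to the discriminator.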

This exercise is part of the course

Deep Learning for Text with PyTorch

Hands-on interactive exercise

Try this exercise by completing this sample code.

# Define the loss function and optimizer
criterion = nn.____()
optimizer_gen = ____(generator.parameters(), lr=0.001)
optimizer_disc = ____(discriminator.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for real_data in data:
        # Unsqueeze real_data and prevent gradient recalculations
        real_data = real_data.____(0)
        noise = torch.rand((1, seq_length))
        fake_data = generator(noise)
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.____())
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()
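The scaffold above stops after the discriminator update. As a rough illustration of how a full training step could look, the sketch below adds a generator update and a final comparison of real and generated sequences. Everything that fills a blank is an assumption and not the exercise's confirmed solution: binary cross-entropy (`nn.BCELoss`) as the criterion, `torch.optim.Adam` for both optimizers, `unsqueeze(0)` to add a batch dimension, and `detach()` to keep the discriminator update from backpropagating into the generator. The models and `data` are hypothetical stand-ins as well.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_length, num_sequences, num_epochs, print_every = 5, 100, 50, 10

# ASSUMPTION: stand-ins for the preloaded data and networks
data = torch.randint(0, 2, (num_sequences, seq_length)).float()
generator = nn.Sequential(nn.Linear(seq_length, seq_length), nn.Sigmoid())
discriminator = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid())

# ASSUMPTION: BCE loss and Adam are one plausible way to fill the blanks
criterion = nn.BCELoss()
optimizer_gen = torch.optim.Adam(generator.parameters(), lr=0.001)
optimizer_disc = torch.optim.Adam(discriminator.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for real_data in data:
        real_data = real_data.unsqueeze(0)  # shape (seq_length,) -> (1, seq_length)
        noise = torch.rand((1, seq_length))
        fake_data = generator(noise)

        # Discriminator step: real samples labeled 1, fake samples labeled 0.
        # detach() cuts the graph so no gradients reach the generator here.
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.detach())
        loss_disc = (criterion(disc_real, torch.ones_like(disc_real))
                     + criterion(disc_fake, torch.zeros_like(disc_fake)))
        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()

        # Generator step: try to make the discriminator output 1 for fakes
        disc_fake = discriminator(fake_data)
        loss_gen = criterion(disc_fake, torch.ones_like(disc_fake))
        optimizer_gen.zero_grad()
        loss_gen.backward()
        optimizer_gen.step()

    if (epoch + 1) % print_every == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}: "
              f"loss_disc={loss_disc.item():.4f}, loss_gen={loss_gen.item():.4f}")

# Compare a few real sequences with rounded generator outputs
with torch.no_grad():
    samples = generator(torch.rand((5, seq_length))).round()
print("Real data:\n", data[:5])
print("Generated data:\n", samples)
```

Rounding the generator output is only a convenience for eyeballing the comparison against binary data; with a different dataset, a different comparison would be more appropriate.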