CommencerCommencer gratuitement

Entraîner un modèle GAN

Votre équipe chez PyBooks a bien avancé dans la création du générateur de texte avec un Generative Adversarial Network (GAN). Vous avez défini avec succès les réseaux générateur et discriminateur. Il est maintenant temps de les entraîner. L’étape finale consiste à générer des données factices et à les comparer aux données réelles pour évaluer ce que votre GAN a appris. Nous avons utilisé des tenseurs en entrée, et la sortie cherchera à ressembler aux tenseurs d’entrée. L’équipe PyBooks pourra ensuite utiliser ces données synthétiques pour l’analyse de texte, car les caractéristiques conserveront les mêmes relations que les données textuelles.

Le générateur et le discriminateur ont été initialisés et enregistrés dans generator et discriminator, respectivement.

Les variables suivantes ont été initialisées dans l’exercice :

  • seq_length = 5 : longueur de chaque séquence de données synthétiques
  • num_sequences = 100 : nombre total de séquences générées
  • num_epochs = 50 : nombre de passages complets sur l’ensemble de données
  • print_every = 10 : fréquence d’affichage des résultats, toutes les 10 époques

Cet exercice fait partie du cours

Deep Learning pour le texte avec PyTorch

Afficher le cours

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

# Define the loss function and optimizer
criterion = nn.____()
optimizer_gen = ____(generator.parameters(), lr=0.001)
optimizer_disc = ____(discriminator.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for real_data in data:
      	# Unsqueezing real_data and prevent gradient recalculations
        real_data = real_data.____(0)
        noise = torch.rand((1, seq_length))
        fake_data = generator(noise)
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.____())
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()
Modifier et exécuter le code