Entraîner un modèle GAN
Votre équipe chez PyBooks a bien avancé dans la création du générateur de texte avec un Generative Adversarial Network (GAN). Vous avez défini avec succès les réseaux générateur et discriminateur. Il est maintenant temps de les entraîner. L’étape finale consiste à générer des données factices et à les comparer aux données réelles pour évaluer ce que votre GAN a appris. Nous avons utilisé des tenseurs en entrée, et la sortie cherchera à ressembler aux tenseurs d’entrée. L’équipe PyBooks pourra ensuite utiliser ces données synthétiques pour l’analyse de texte, car les caractéristiques conserveront les mêmes relations que les données textuelles.
Le générateur et le discriminateur ont été initialisés et enregistrés dans generator et discriminator, respectivement.
Les variables suivantes ont été initialisées dans l’exercice :
seq_length = 5: longueur de chaque séquence de données synthétiquesnum_sequences = 100: nombre total de séquences généréesnum_epochs = 50: nombre de passages complets sur l’ensemble de donnéesprint_every = 10: fréquence d’affichage des résultats, toutes les 10 époques
Cet exercice fait partie du cours
Deep Learning pour le texte avec PyTorch
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Define the loss function and optimizer
criterion = nn.____()
optimizer_gen = ____(generator.parameters(), lr=0.001)
optimizer_disc = ____(discriminator.parameters(), lr=0.001)
for epoch in range(num_epochs):
for real_data in data:
# Unsqueezing real_data and prevent gradient recalculations
real_data = real_data.____(0)
noise = torch.rand((1, seq_length))
fake_data = generator(noise)
disc_real = discriminator(real_data)
disc_fake = discriminator(fake_data.____())
loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
optimizer_disc.zero_grad()
loss_disc.backward()
optimizer_disc.step()