Een GAN-model trainen
Je team bij PyBooks heeft mooie vooruitgang geboekt met het bouwen van de tekstgenerator met een Generative Adversarial Network (GAN). Je hebt de generator- en discriminatornetwerken succesvol gedefinieerd. Nu is het tijd om ze te trainen. De laatste stap is om nepdata te genereren en die te vergelijken met echte data om te zien hoe goed je GAN heeft geleerd. We hebben tensors als input gebruikt en de output zal proberen op de inputtensors te lijken. Het team bij PyBooks kan deze synthetische data vervolgens gebruiken voor tekstanalyse, omdat de kenmerken dezelfde relaties zullen hebben als tekstdata.
De generator en discriminator zijn geïnitialiseerd en opgeslagen in respectievelijk generator en discriminator.
De volgende variabelen zijn in de oefening geïnitialiseerd:
seq_length = 5: Lengte van elke synthetische datasequentienum_sequences = 100: Totaal aantal gegenereerde sequentiesnum_epochs = 50: Aantal volledige passes door de gegevenssetprint_every = 10: Frequentie van de uitvoer, toont resultaten elke 10 epochs
Deze oefening maakt deel uit van de cursus
Deep Learning voor tekst met PyTorch
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
# Define the loss function and optimizer
criterion = nn.____()
optimizer_gen = ____(generator.parameters(), lr=0.001)
optimizer_disc = ____(discriminator.parameters(), lr=0.001)
for epoch in range(num_epochs):
for real_data in data:
# Unsqueezing real_data and prevent gradient recalculations
real_data = real_data.____(0)
noise = torch.rand((1, seq_length))
fake_data = generator(noise)
disc_real = discriminator(real_data)
disc_fake = discriminator(fake_data.____())
loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
optimizer_disc.zero_grad()
loss_disc.backward()
optimizer_disc.step()