Aan de slagGa gratis aan de slag

Een GAN-model trainen

Je team bij PyBooks heeft mooie vooruitgang geboekt met het bouwen van de tekstgenerator met een Generative Adversarial Network (GAN). Je hebt de generator- en discriminatornetwerken succesvol gedefinieerd. Nu is het tijd om ze te trainen. De laatste stap is om nepdata te genereren en die te vergelijken met echte data om te zien hoe goed je GAN heeft geleerd. We hebben tensors als input gebruikt en de output zal proberen op de inputtensors te lijken. Het team bij PyBooks kan deze synthetische data vervolgens gebruiken voor tekstanalyse, omdat de kenmerken dezelfde relaties zullen hebben als tekstdata.

De generator en discriminator zijn geïnitialiseerd en opgeslagen in respectievelijk generator en discriminator.

De volgende variabelen zijn in de oefening geïnitialiseerd:

  • seq_length = 5: Lengte van elke synthetische datasequentie
  • num_sequences = 100: Totaal aantal gegenereerde sequenties
  • num_epochs = 50: Aantal volledige passes door de gegevensset
  • print_every = 10: Frequentie van de uitvoer, toont resultaten elke 10 epochs

Deze oefening maakt deel uit van de cursus

Deep Learning voor tekst met PyTorch

Cursus bekijken

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

# Define the loss function and optimizer
criterion = nn.____()
optimizer_gen = ____(generator.parameters(), lr=0.001)
optimizer_disc = ____(discriminator.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for real_data in data:
      	# Unsqueezing real_data and prevent gradient recalculations
        real_data = real_data.____(0)
        noise = torch.rand((1, seq_length))
        fake_data = generator(noise)
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.____())
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()
Code bewerken en uitvoeren