IniziaInizia gratis

Addestrare un modello GAN

Il tuo team di PyBooks ha fatto buoni progressi nel costruire il generatore di testo usando una Generative Adversarial Network (GAN). Hai definito con successo le reti generator e discriminator. Ora è il momento di addestrarle. L’ultimo passaggio è generare alcuni dati fittizi e confrontarli con i dati reali per vedere quanto bene la tua GAN ha imparato. Abbiamo usato tensori come input e l’output cercherà di somigliare ai tensori di input. Il team di PyBooks potrà poi usare questi dati sintetici per l’analisi del testo, perché le feature manterranno le stesse relazioni dei dati testuali.

Il generator e il discriminator sono stati inizializzati e salvati rispettivamente in generator e discriminator.

Nel seguente esercizio sono state inizializzate queste variabili:

  • seq_length = 5: Lunghezza di ciascuna sequenza di dati sintetici
  • num_sequences = 100: Numero totale di sequenze generate
  • num_epochs = 50: Numero di passaggi completi sul dataset
  • print_every = 10: Frequenza di visualizzazione dei risultati, mostra l’output ogni 10 epoche

Questo esercizio fa parte del corso

Deep Learning per il testo con PyTorch

Visualizza il corso

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

# Define the loss function and optimizer
criterion = nn.____()
optimizer_gen = ____(generator.parameters(), lr=0.001)
optimizer_disc = ____(discriminator.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for real_data in data:
      	# Unsqueezing real_data and prevent gradient recalculations
        real_data = real_data.____(0)
        noise = torch.rand((1, seq_length))
        fake_data = generator(noise)
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.____())
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()
Modifica ed esegui il codice