Loss del generatore
Prima di poter addestrare la tua GAN, devi definire le funzioni di loss sia per il generatore sia per il discriminatore. Inizierai con la prima.
Ricorda che il compito del generatore è produrre immagini false tali da ingannare il discriminatore facendogliele classificare come reali. Di conseguenza, il generatore subisce una loss se le immagini che ha generato vengono classificate dal discriminatore come false (etichetta 0).
Definisci la funzione gen_loss() che calcola la loss del generatore. Accetta quattro argomenti:
gen, il modello generatoredisc, il modello discriminatorenum_images, il numero di immagini nel batchz_dim, la dimensione del rumore casuale in input
Questo esercizio fa parte del corso
Deep Learning per Immagini con PyTorch
Istruzioni dell'esercizio
- Genera rumore casuale di forma
num_imagesperz_dime assegnalo anoise. - Usa il generatore per creare un'immagine finta a partire da
noisee assegnala afake. - Ottieni la previsione del discriminatore per l'immagine finta generata.
- Calcola la loss del generatore chiamando
criterionsulle previsioni del discriminatore e su un tensore di uni della stessa forma.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
def gen_loss(gen, disc, criterion, num_images, z_dim):
# Define random noise
noise = ____(num_images, z_dim)
# Generate fake image
fake = ____
# Get discriminator's prediction on the fake image
disc_pred = ____
# Compute generator loss
criterion = nn.BCEWithLogitsLoss()
gen_loss = ____(____, ____)
return gen_loss