Perda do gerador
Antes de treinar sua GAN, você precisa definir funções de perda tanto para o gerador quanto para o discriminador. Você vai começar com o primeiro.
Lembre-se de que o trabalho do gerador é criar imagens falsas que enganem o discriminador, fazendo-o classificá-las como reais. Então, o gerador perde se as imagens que ele criou forem classificadas pelo discriminador como falsas (rótulo 0
).
Defina a função gen_loss()
que calcula a perda do gerador. Ele precisa de quatro argumentos:
gen
, o modelo do geradordisc
, o modelo discriminadornum_images
, o número de imagens no lotez_dim
, o tamanho do ruído aleatório de entrada
Este exercício faz parte do curso
Aprendizado profundo para imagens com PyTorch
Instruções do exercício
- Crie um ruído aleatório com a forma
num_images
usandoz_dim
e coloque ele emnoise
. - Use o gerador para criar uma imagem falsa a partir de
noise
e atribua-a afake
. - Pega a previsão do discriminador pra imagem falsa que foi gerada.
- Calcule a perda dos geradores chamando
criterion
nas previsões do discriminador e o tensor de uns da mesma forma.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
def gen_loss(gen, disc, criterion, num_images, z_dim):
# Define random noise
noise = ____(num_images, z_dim)
# Generate fake image
fake = ____
# Get discriminator's prediction on the fake image
disc_pred = ____
# Compute generator loss
criterion = nn.BCEWithLogitsLoss()
gen_loss = ____(____, ____)
return gen_loss