Perda do gerador
Antes de treinar sua GAN, você precisa definir funções de perda tanto para o gerador quanto para o discriminador. Você vai começar com o primeiro.
Lembre-se de que o trabalho do gerador é criar imagens falsas que enganem o discriminador, fazendo-o classificá-las como reais. Então, o gerador perde se as imagens que ele criou forem classificadas pelo discriminador como falsas (rótulo 0).
Defina a função gen_loss() que calcula a perda do gerador. Ele precisa de quatro argumentos:
gen, o modelo do geradordisc, o modelo discriminadornum_images, o número de imagens no lotez_dim, o tamanho do ruído aleatório de entrada
Este exercício faz parte do curso
Aprendizado profundo para imagens com PyTorch
Instruções do exercício
- Crie um ruído aleatório com a forma
num_imagesusandoz_dime coloque ele emnoise. - Use o gerador para criar uma imagem falsa a partir de
noisee atribua-a afake. - Pega a previsão do discriminador pra imagem falsa que foi gerada.
- Calcule a perda dos geradores chamando
criterionnas previsões do discriminador e o tensor de uns da mesma forma.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
def gen_loss(gen, disc, criterion, num_images, z_dim):
# Define random noise
noise = ____(num_images, z_dim)
# Generate fake image
fake = ____
# Get discriminator's prediction on the fake image
disc_pred = ____
# Compute generator loss
criterion = nn.BCEWithLogitsLoss()
gen_loss = ____(____, ____)
return gen_loss