Perda do gerador
Antes de treinar sua GAN, você precisa definir as funções de perda para o gerador e para o discriminador. Você vai começar pelo gerador.
Lembre que a função do gerador é produzir imagens falsas que consigam enganar o discriminador, levando-o a classificá-las como reais. Portanto, o gerador sofre perda quando as imagens que ele gera são classificadas pelo discriminador como falsas (rótulo 0).
Defina a função gen_loss() que calcula a perda do gerador. Ela recebe quatro argumentos:
gen, o modelo geradordisc, o modelo discriminadornum_images, o número de imagens no lotez_dim, o tamanho do ruído aleatório de entrada
Este exercício faz parte do curso
Deep Learning para Imagens com PyTorch
Instruções do exercício
- Gere ruído aleatório com formato
num_imagesporz_dime atribua anoise. - Use o gerador para criar uma imagem falsa a partir de
noisee atribua afake. - Obtenha a previsão do discriminador para a imagem falsa gerada.
- Calcule a perda do gerador chamando
criterioncom as previsões do discriminador e um tensor de uns do mesmo formato.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
def gen_loss(gen, disc, criterion, num_images, z_dim):
# Define random noise
noise = ____(num_images, z_dim)
# Generate fake image
fake = ____
# Get discriminator's prediction on the fake image
disc_pred = ____
# Compute generator loss
criterion = nn.BCEWithLogitsLoss()
gen_loss = ____(____, ____)
return gen_loss