Perda do discriminador
É hora de definir a função de perda do discriminador. Lembre que a tarefa do discriminador é classificar imagens como reais ou falsas. Portanto, o discriminador incorre em perda se ele classificar as saídas do gerador como reais (rótulo 1) ou as imagens reais como falsas (rótulo 0).
Defina a função disc_loss() que calcula a perda do discriminador. Ela recebe cinco argumentos:
gen, o modelo geradordisc, o modelo discriminadorreal, uma amostra de imagens reais dos dados de treinonum_images, o número de imagens no lotez_dim, o tamanho do ruído aleatório de entrada
Este exercício faz parte do curso
Deep Learning para Imagens com PyTorch
Instruções do exercício
- Use o discriminador para classificar imagens
fakee atribua as previsões adisc_pred_fake. - Calcule o componente de perda para as falsas chamando
criterionnas previsões do discriminador para imagens falsas e em um tensor de zeros do mesmo formato. - Use o discriminador para classificar imagens
reale atribua as previsões adisc_pred_real. - Calcule o componente de perda para as reais chamando
criterionnas previsões do discriminador para imagens reais e em um tensor de uns do mesmo formato.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss