Perda do discriminador
É hora de definir a perda para o discriminador. Lembre-se de que o trabalho do discriminador é classificar as imagens como reais ou falsas. Então, o gerador perde se classificar as saídas do gerador como reais (rótulo 1
) ou as imagens reais como falsas (rótulo 0
).
Defina a função “ disc_loss()
” que calcula a perda do discriminador. Ele precisa de cinco argumentos:
gen
, o modelo do geradordisc
, o modelo discriminadorreal
, uma amostra de imagens reais dos dados de treinamentonum_images
, o número de imagens no lotez_dim
, o tamanho do ruído aleatório de entrada
Este exercício faz parte do curso
Aprendizado profundo para imagens com PyTorch
Instruções do exercício
- Use o discriminador para classificar imagens
fake
e atribuir as previsões adisc_pred_fake
. - Calcule o componente de perda falsa chamando
criterion
nas previsões do discriminador para imagens falsas e o tensor de zeros da mesma forma. - Use o discriminador para classificar imagens
real
e atribuir as previsões adisc_pred_real
. - Calcule o componente de perda real chamando
criterion
nas previsões do discriminador para imagens reais e o tensor de uns da mesma forma.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss