ComeçarComece de graça

Perda do discriminador

É hora de definir a perda para o discriminador. Lembre-se de que o trabalho do discriminador é classificar as imagens como reais ou falsas. Então, o gerador perde se classificar as saídas do gerador como reais (rótulo 1) ou as imagens reais como falsas (rótulo 0).

Defina a função “ disc_loss() ” que calcula a perda do discriminador. Ele precisa de cinco argumentos:

  • gen, o modelo do gerador
  • disc, o modelo discriminador
  • real, uma amostra de imagens reais dos dados de treinamento
  • num_images, o número de imagens no lote
  • z_dim, o tamanho do ruído aleatório de entrada

Este exercício faz parte do curso

Aprendizado profundo para imagens com PyTorch

Ver curso

Instruções do exercício

  • Use o discriminador para classificar imagens fake e atribuir as previsões a disc_pred_fake.
  • Calcule o componente de perda falsa chamando criterion nas previsões do discriminador para imagens falsas e o tensor de zeros da mesma forma.
  • Use o discriminador para classificar imagens real e atribuir as previsões a disc_pred_real.
  • Calcule o componente de perda real chamando criterion nas previsões do discriminador para imagens reais e o tensor de uns da mesma forma.

Exercício interativo prático

Experimente este exercício completando este código de exemplo.

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
Editar e executar o código