ComeçarComece de graça

Perda do discriminador

É hora de definir a função de perda do discriminador. Lembre que a tarefa do discriminador é classificar imagens como reais ou falsas. Portanto, o discriminador incorre em perda se ele classificar as saídas do gerador como reais (rótulo 1) ou as imagens reais como falsas (rótulo 0).

Defina a função disc_loss() que calcula a perda do discriminador. Ela recebe cinco argumentos:

  • gen, o modelo gerador
  • disc, o modelo discriminador
  • real, uma amostra de imagens reais dos dados de treino
  • num_images, o número de imagens no lote
  • z_dim, o tamanho do ruído aleatório de entrada

Este exercício faz parte do curso

Deep Learning para Imagens com PyTorch

Ver curso

Instruções do exercício

  • Use o discriminador para classificar imagens fake e atribua as previsões a disc_pred_fake.
  • Calcule o componente de perda para as falsas chamando criterion nas previsões do discriminador para imagens falsas e em um tensor de zeros do mesmo formato.
  • Use o discriminador para classificar imagens real e atribua as previsões a disc_pred_real.
  • Calcule o componente de perda para as reais chamando criterion nas previsões do discriminador para imagens reais e em um tensor de uns do mesmo formato.

Exercício interativo prático

Experimente este exercício completando este código de exemplo.

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
Editar e executar o código