ComenzarEmpieza gratis

Pérdida del discriminador

Es hora de definir la pérdida para el discriminador. Recordemos que la función del discriminador es clasificar las imágenes como reales o falsas. Por lo tanto, el generador incurre en una pérdida si clasifica las salidas del generador como reales (etiqueta 1) o las imágenes reales como falsas (etiqueta 0).

Define la función « disc_loss() » que calcula la pérdida discriminatoria. Toma cinco argumentos:

  • gen, el modelo del generador
  • disc, el modelo discriminador
  • real, una muestra de imágenes reales de los datos de entrenamiento.
  • num_images, el número de imágenes en el lote
  • z_dim, el tamaño del ruido aleatorio de entrada

Este ejercicio forma parte del curso

Aprendizaje profundo para imágenes con PyTorch

Ver curso

Instrucciones del ejercicio

  • Utiliza el discriminador para clasificar las imágenes de fake y asigna las predicciones a disc_pred_fake.
  • Calcula el componente de pérdida falsa llamando a criterion sobre las predicciones del discriminador para las imágenes falsas y un tensor de ceros de la misma forma.
  • Utiliza el discriminador para clasificar las imágenes de real y asigna las predicciones a disc_pred_real.
  • Calcula el componente de pérdida real llamando a criterion sobre las predicciones del discriminador para imágenes reales y el tensor de unos de la misma forma.

Ejercicio interactivo práctico

Prueba este ejercicio completando el código de muestra.

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
Editar y ejecutar código