ComenzarEmpieza gratis

Pérdida del discriminador

Es momento de definir la función de pérdida para el discriminador. Recuerda que el trabajo del discriminador es clasificar las imágenes como reales o falsas. Por lo tanto, el discriminador incurre en pérdida si clasifica las salidas del generador como reales (etiqueta 1) o las imágenes reales como falsas (etiqueta 0).

Define la función disc_loss() que calcula la pérdida del discriminador. Recibe cinco argumentos:

  • gen, el modelo generador
  • disc, el modelo discriminador
  • real, una muestra de imágenes reales del conjunto de entrenamiento
  • num_images, el número de imágenes en el lote
  • z_dim, el tamaño del ruido aleatorio de entrada

Este ejercicio forma parte del curso

Deep Learning para imágenes con PyTorch

Ver curso

Instrucciones del ejercicio

  • Usa el discriminador para clasificar las imágenes fake y asigna las predicciones a disc_pred_fake.
  • Calcula el componente de pérdida para las falsas llamando a criterion sobre las predicciones del discriminador para imágenes falsas y un tensor de ceros de la misma forma.
  • Usa el discriminador para clasificar las imágenes real y asigna las predicciones a disc_pred_real.
  • Calcula el componente de pérdida para las reales llamando a criterion sobre las predicciones del discriminador para imágenes reales y un tensor de unos de la misma forma.

Ejercicio interactivo práctico

Prueba este ejercicio y completa el código de muestra.

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Get discriminator's predictions for fake images
    disc_pred_fake = ____
    # Calculate the fake loss component
    fake_loss = ____
    # Get discriminator's predictions for real images
    disc_pred_real = ____
    # Calculate the real loss component
    real_loss = ____
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
Editar y ejecutar código