ComenzarEmpieza gratis

Pérdida del generador

Antes de poder entrenar tu GAN, debes definir funciones de pérdida tanto para el generador como para el discriminador. Empezarás con lo primero.

Recordemos que la función del generador es producir imágenes falsas que engañen al discriminador y le hagan clasificarlas como reales. Por lo tanto, el generador incurre en una pérdida si las imágenes que ha generado son clasificadas por el discriminador como falsas (etiqueta 0).

Define la función gen_loss() que calcula la pérdida del generador. Toma cuatro argumentos:

  • gen, el modelo del generador
  • disc, el modelo discriminador
  • num_images, el número de imágenes en el lote
  • z_dim, el tamaño del ruido aleatorio de entrada

Este ejercicio forma parte del curso

Aprendizaje profundo para imágenes con PyTorch

Ver curso

Instrucciones del ejercicio

  • Genera ruido aleatorio con forma num_images mediante z_dim y asígnalo a noise.
  • Utiliza el generador para crear una imagen falsa de noise y asígnala a fake.
  • Obtener la predicción del discriminador para la imagen falsa generada.
  • Calcula la pérdida de los generadores llamando a criterion sobre las predicciones del discriminador y el tensor de unos de la misma forma.

Ejercicio interactivo práctico

Prueba este ejercicio completando el código de muestra.

def gen_loss(gen, disc, criterion, num_images, z_dim):
    # Define random noise
    noise = ____(num_images, z_dim)
    # Generate fake image
    fake = ____
    # Get discriminator's prediction on the fake image
    disc_pred = ____
    # Compute generator loss
    criterion = nn.BCEWithLogitsLoss()
    gen_loss = ____(____, ____)
    return gen_loss
Editar y ejecutar código