ComenzarEmpieza gratis

Bucle de entrenamiento

Por fin, todo el arduo trabajo que has dedicado a definir las arquitecturas del modelo y las funciones de pérdida da sus frutos: ¡es hora de entrenar! Tu trabajo consiste en implementar y ejecutar el bucle de entrenamiento GAN. Nota: se coloca una instrucción « break » después del primer lote de datos para evitar un tiempo de ejecución prolongado.

Los dos optimizadores, disc_opt y gen_opt, se han inicializado como optimizadores de Adam(). Las funciones para calcular las pérdidas que definiste anteriormente, gen_loss() y disc_loss(), están disponibles. También se ha preparado un documento titulado « dataloader » (Resumen de la política de privacidad de la UE) para ti.

Recordemos que:

  • disc_loss()Los argumentos de 's son: gen , disc, real, cur_batch_size, z_dim.
  • gen_loss()Los argumentos de 's son: gen , disc, cur_batch_size, z_dim.

Este ejercicio forma parte del curso

Aprendizaje profundo para imágenes con PyTorch

Ver curso

Instrucciones del ejercicio

  • Calcula la pérdida del discriminador utilizando disc_loss() pasando el generador, el discriminador, la muestra de imágenes reales, el tamaño del lote actual y el tamaño del ruido de 16, en este orden, y asigna el resultado a d_loss.
  • Calcula los gradientes utilizando d_loss.
  • Calcula la pérdida del generador utilizando gen_loss() pasando el generador, el discriminador, el tamaño del lote actual y el tamaño del ruido de 16, en este orden, y asigna el resultado a g_loss.
  • Calcula los gradientes utilizando g_loss.

Ejercicio interactivo práctico

Prueba este ejercicio completando el código de muestra.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Editar y ejecutar código