ComenzarEmpieza gratis

Bucle de entrenamiento

Por fin, todo el esfuerzo de definir las arquitecturas del modelo y las funciones de pérdida da sus frutos: ¡toca entrenar! Tu tarea es implementar y ejecutar el bucle de entrenamiento de la GAN. Nota: hay una sentencia break después del primer lote de datos para evitar tiempos de ejecución largos.

Los dos optimizadores, disc_opt y gen_opt, se han inicializado como optimizadores Adam(). Las funciones para calcular las pérdidas que definiste antes, gen_loss() y disc_loss(), están disponibles. También tienes preparado un dataloader.

Recuerda que:

  • Los argumentos de disc_loss() son: gen, disc, real, cur_batch_size, z_dim.
  • Los argumentos de gen_loss() son: gen, disc, cur_batch_size, z_dim.

Este ejercicio forma parte del curso

Deep Learning para imágenes con PyTorch

Ver curso

Instrucciones del ejercicio

  • Calcula la pérdida del discriminador usando disc_loss() pasándole, en este orden, el generador, el discriminador, la muestra de imágenes reales, el tamaño de lote actual y el tamaño del ruido 16, y asigna el resultado a d_loss.
  • Calcula los gradientes usando d_loss.
  • Calcula la pérdida del generador usando gen_loss() pasándole, en este orden, el generador, el discriminador, el tamaño de lote actual y el tamaño del ruido 16, y asigna el resultado a g_loss.
  • Calcula los gradientes usando g_loss.

Ejercicio interactivo práctico

Prueba este ejercicio y completa el código de muestra.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Editar y ejecutar código