Bucle de entrenamiento
Por fin, todo el esfuerzo de definir las arquitecturas del modelo y las funciones de pérdida da sus frutos: ¡toca entrenar! Tu tarea es implementar y ejecutar el bucle de entrenamiento de la GAN. Nota: hay una sentencia break después del primer lote de datos para evitar tiempos de ejecución largos.
Los dos optimizadores, disc_opt y gen_opt, se han inicializado como optimizadores Adam(). Las funciones para calcular las pérdidas que definiste antes, gen_loss() y disc_loss(), están disponibles. También tienes preparado un dataloader.
Recuerda que:
- Los argumentos de
disc_loss()son:gen,disc,real,cur_batch_size,z_dim. - Los argumentos de
gen_loss()son:gen,disc,cur_batch_size,z_dim.
Este ejercicio forma parte del curso
Deep Learning para imágenes con PyTorch
Instrucciones del ejercicio
- Calcula la pérdida del discriminador usando
disc_loss()pasándole, en este orden, el generador, el discriminador, la muestra de imágenes reales, el tamaño de lote actual y el tamaño del ruido16, y asigna el resultado ad_loss. - Calcula los gradientes usando
d_loss. - Calcula la pérdida del generador usando
gen_loss()pasándole, en este orden, el generador, el discriminador, el tamaño de lote actual y el tamaño del ruido16, y asigna el resultado ag_loss. - Calcula los gradientes usando
g_loss.
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break