Bucle de entrenamiento
Por fin, todo el arduo trabajo que has dedicado a definir las arquitecturas del modelo y las funciones de pérdida da sus frutos: ¡es hora de entrenar! Tu trabajo consiste en implementar y ejecutar el bucle de entrenamiento GAN. Nota: se coloca una instrucción « break » después del primer lote de datos para evitar un tiempo de ejecución prolongado.
Los dos optimizadores, disc_opt y gen_opt, se han inicializado como optimizadores de Adam(). Las funciones para calcular las pérdidas que definiste anteriormente, gen_loss() y disc_loss(), están disponibles. También se ha preparado un documento titulado « dataloader » (Resumen de la política de privacidad de la UE) para ti.
Recordemos que:
disc_loss()Los argumentos de 's son:gen,disc,real,cur_batch_size,z_dim.gen_loss()Los argumentos de 's son:gen,disc,cur_batch_size,z_dim.
Este ejercicio forma parte del curso
Aprendizaje profundo para imágenes con PyTorch
Instrucciones del ejercicio
- Calcula la pérdida del discriminador utilizando
disc_loss()pasando el generador, el discriminador, la muestra de imágenes reales, el tamaño del lote actual y el tamaño del ruido de16, en este orden, y asigna el resultado ad_loss. - Calcula los gradientes utilizando
d_loss. - Calcula la pérdida del generador utilizando
gen_loss()pasando el generador, el discriminador, el tamaño del lote actual y el tamaño del ruido de16, en este orden, y asigna el resultado ag_loss. - Calcula los gradientes utilizando
g_loss.
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break