Bucle de entrenamiento
Por fin, todo el arduo trabajo que has dedicado a definir las arquitecturas del modelo y las funciones de pérdida da sus frutos: ¡es hora de entrenar! Tu trabajo consiste en implementar y ejecutar el bucle de entrenamiento GAN. Nota: se coloca una instrucción « break
» después del primer lote de datos para evitar un tiempo de ejecución prolongado.
Los dos optimizadores, disc_opt
y gen_opt
, se han inicializado como optimizadores de Adam()
. Las funciones para calcular las pérdidas que definiste anteriormente, gen_loss()
y disc_loss()
, están disponibles. También se ha preparado un documento titulado « dataloader
» (Resumen de la política de privacidad de la UE) para ti.
Recordemos que:
disc_loss()
Los argumentos de 's son:gen
,disc
,real
,cur_batch_size
,z_dim
.gen_loss()
Los argumentos de 's son:gen
,disc
,cur_batch_size
,z_dim
.
Este ejercicio forma parte del curso
Aprendizaje profundo para imágenes con PyTorch
Instrucciones del ejercicio
- Calcula la pérdida del discriminador utilizando
disc_loss()
pasando el generador, el discriminador, la muestra de imágenes reales, el tamaño del lote actual y el tamaño del ruido de16
, en este orden, y asigna el resultado ad_loss
. - Calcula los gradientes utilizando
d_loss
. - Calcula la pérdida del generador utilizando
gen_loss()
pasando el generador, el discriminador, el tamaño del lote actual y el tamaño del ruido de16
, en este orden, y asigna el resultado ag_loss
. - Calcula los gradientes utilizando
g_loss
.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break