Pérdida del discriminador
Es momento de definir la función de pérdida para el discriminador. Recuerda que el trabajo del discriminador es clasificar las imágenes como reales o falsas. Por lo tanto, el discriminador incurre en pérdida si clasifica las salidas del generador como reales (etiqueta 1) o las imágenes reales como falsas (etiqueta 0).
Define la función disc_loss() que calcula la pérdida del discriminador. Recibe cinco argumentos:
gen, el modelo generadordisc, el modelo discriminadorreal, una muestra de imágenes reales del conjunto de entrenamientonum_images, el número de imágenes en el lotez_dim, el tamaño del ruido aleatorio de entrada
Este ejercicio forma parte del curso
Deep Learning para imágenes con PyTorch
Instrucciones del ejercicio
- Usa el discriminador para clasificar las imágenes
fakey asigna las predicciones adisc_pred_fake. - Calcula el componente de pérdida para las falsas llamando a
criterionsobre las predicciones del discriminador para imágenes falsas y un tensor de ceros de la misma forma. - Usa el discriminador para clasificar las imágenes
realy asigna las predicciones adisc_pred_real. - Calcula el componente de pérdida para las reales llamando a
criterionsobre las predicciones del discriminador para imágenes reales y un tensor de unos de la misma forma.
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss