Pérdida del discriminador
Es hora de definir la pérdida para el discriminador. Recordemos que la función del discriminador es clasificar las imágenes como reales o falsas. Por lo tanto, el generador incurre en una pérdida si clasifica las salidas del generador como reales (etiqueta 1
) o las imágenes reales como falsas (etiqueta 0
).
Define la función « disc_loss()
» que calcula la pérdida discriminatoria. Toma cinco argumentos:
gen
, el modelo del generadordisc
, el modelo discriminadorreal
, una muestra de imágenes reales de los datos de entrenamiento.num_images
, el número de imágenes en el lotez_dim
, el tamaño del ruido aleatorio de entrada
Este ejercicio forma parte del curso
Aprendizaje profundo para imágenes con PyTorch
Instrucciones del ejercicio
- Utiliza el discriminador para clasificar las imágenes de
fake
y asigna las predicciones adisc_pred_fake
. - Calcula el componente de pérdida falsa llamando a
criterion
sobre las predicciones del discriminador para las imágenes falsas y un tensor de ceros de la misma forma. - Utiliza el discriminador para clasificar las imágenes de
real
y asigna las predicciones adisc_pred_real
. - Calcula el componente de pérdida real llamando a
criterion
sobre las predicciones del discriminador para imágenes reales y el tensor de unos de la misma forma.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
def disc_loss(gen, disc, real, num_images, z_dim):
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
# Get discriminator's predictions for fake images
disc_pred_fake = ____
# Calculate the fake loss component
fake_loss = ____
# Get discriminator's predictions for real images
disc_pred_real = ____
# Calculate the real loss component
real_loss = ____
disc_loss = (real_loss + fake_loss) / 2
return disc_loss