Pérdida del generador
Antes de poder entrenar tu GAN, necesitas definir las funciones de pérdida tanto para el generador como para el discriminador. Empezarás por la del primero.
Recuerda que el objetivo del generador es producir imágenes falsas que consigan engañar al discriminador para que las clasifique como reales. Por lo tanto, el generador sufre pérdida si las imágenes que genera son clasificadas por el discriminador como falsas (etiqueta 0).
Define la función gen_loss() que calcule la pérdida del generador. Recibe cuatro argumentos:
gen, el modelo generadordisc, el modelo discriminadornum_images, el número de imágenes en el lotez_dim, el tamaño del ruido aleatorio de entrada
Este ejercicio forma parte del curso
Deep Learning para imágenes con PyTorch
Instrucciones del ejercicio
- Genera ruido aleatorio de forma
num_imagesporz_dimy asígnalo anoise. - Usa el generador para crear una imagen falsa a partir de
noisey asígnala afake. - Obtén la predicción del discriminador para la imagen falsa generada.
- Calcula la pérdida del generador llamando a
criterioncon las predicciones del discriminador y un tensor de unos de la misma forma.
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
def gen_loss(gen, disc, criterion, num_images, z_dim):
# Define random noise
noise = ____(num_images, z_dim)
# Generate fake image
fake = ____
# Get discriminator's prediction on the fake image
disc_pred = ____
# Compute generator loss
criterion = nn.BCEWithLogitsLoss()
gen_loss = ____(____, ____)
return gen_loss