Pérdida del generador
Antes de poder entrenar tu GAN, debes definir funciones de pérdida tanto para el generador como para el discriminador. Empezarás con lo primero.
Recordemos que la función del generador es producir imágenes falsas que engañen al discriminador y le hagan clasificarlas como reales. Por lo tanto, el generador incurre en una pérdida si las imágenes que ha generado son clasificadas por el discriminador como falsas (etiqueta 0
).
Define la función gen_loss()
que calcula la pérdida del generador. Toma cuatro argumentos:
gen
, el modelo del generadordisc
, el modelo discriminadornum_images
, el número de imágenes en el lotez_dim
, el tamaño del ruido aleatorio de entrada
Este ejercicio forma parte del curso
Aprendizaje profundo para imágenes con PyTorch
Instrucciones del ejercicio
- Genera ruido aleatorio con forma
num_images
mediantez_dim
y asígnalo anoise
. - Utiliza el generador para crear una imagen falsa de
noise
y asígnala afake
. - Obtener la predicción del discriminador para la imagen falsa generada.
- Calcula la pérdida de los generadores llamando a
criterion
sobre las predicciones del discriminador y el tensor de unos de la misma forma.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
def gen_loss(gen, disc, criterion, num_images, z_dim):
# Define random noise
noise = ____(num_images, z_dim)
# Generate fake image
fake = ____
# Get discriminator's prediction on the fake image
disc_pred = ____
# Compute generator loss
criterion = nn.BCEWithLogitsLoss()
gen_loss = ____(____, ____)
return gen_loss