ComenzarEmpieza gratis

Crear un U-Net: método forward

Con las capas del encoder y del decoder definidas, ya puedes implementar el método forward() del U-Net. Las entradas ya se han pasado por el encoder por ti. Sin embargo, necesitas definir el último bloque del decoder.

El objetivo del decoder es hacer upsampling de los mapas de características para que su salida tenga la misma altura y anchura que la imagen de entrada del U-Net. Esto te permitirá obtener máscaras semánticas a nivel de píxel.

Este ejercicio forma parte del curso

Deep Learning para imágenes con PyTorch

Ver curso

Instrucciones del ejercicio

  • Define el último bloque del decoder, usando torch.cat() para formar la conexión de salto.

Ejercicio interactivo práctico

Prueba este ejercicio y completa el código de muestra.

def forward(self, x):
    x1 = self.enc1(x)
    x2 = self.enc2(self.pool(x1))
    x3 = self.enc3(self.pool(x2))
    x4 = self.enc4(self.pool(x3))

    x = self.upconv3(x4)
    x = torch.cat([x, x3], dim=1)
    x = self.dec1(x)

    x = self.upconv2(x)
    x = torch.cat([x, x2], dim=1)
    x = self.dec2(x)

    # Define the last decoder block with skip connections
    x = ____
    x = ____
    x = ____

    return self.out(x)
Editar y ejecutar código