Crear un U-Net: método forward
Con las capas del encoder y del decoder definidas, ya puedes implementar el método forward() del U-Net. Las entradas ya se han pasado por el encoder por ti. Sin embargo, necesitas definir el último bloque del decoder.
El objetivo del decoder es hacer upsampling de los mapas de características para que su salida tenga la misma altura y anchura que la imagen de entrada del U-Net. Esto te permitirá obtener máscaras semánticas a nivel de píxel.
Este ejercicio forma parte del curso
Deep Learning para imágenes con PyTorch
Instrucciones del ejercicio
- Define el último bloque del decoder, usando
torch.cat()para formar la conexión de salto.
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(self.pool(x1))
x3 = self.enc3(self.pool(x2))
x4 = self.enc4(self.pool(x3))
x = self.upconv3(x4)
x = torch.cat([x, x3], dim=1)
x = self.dec1(x)
x = self.upconv2(x)
x = torch.cat([x, x2], dim=1)
x = self.dec2(x)
# Define the last decoder block with skip connections
x = ____
x = ____
x = ____
return self.out(x)