Construindo uma U-Net: método forward
Com as camadas do encoder e do decoder definidas, você pode agora implementar o método forward() da U-Net. As entradas já foram passadas pelo encoder para você. No entanto, você precisa definir o último bloco do decoder.
O objetivo do decoder é fazer o upsampling dos mapas de características para que sua saída tenha a mesma altura e largura da imagem de entrada da U-Net. Isso permite obter máscaras semânticas no nível de pixel.
Este exercício faz parte do curso
Deep Learning para Imagens com PyTorch
Instruções do exercício
- Defina o último bloco do decoder, usando
torch.cat()para formar a conexão de atalho (skip connection).
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(self.pool(x1))
x3 = self.enc3(self.pool(x2))
x4 = self.enc4(self.pool(x3))
x = self.upconv3(x4)
x = torch.cat([x, x3], dim=1)
x = self.dec1(x)
x = self.upconv2(x)
x = torch.cat([x, x2], dim=1)
x = self.dec2(x)
# Define the last decoder block with skip connections
x = ____
x = ____
x = ____
return self.out(x)