Een U-Net bouwen: forward-methode
Nu de encoder- en decoderlagen gedefinieerd zijn, kun je de forward()-methode van de U-Net implementeren. De inputs zijn al door de encoder voor je gehaald. Je moet echter nog het laatste decoderblok definiëren.
Het doel van de decoder is om de feature maps te upsamplen zodat de output dezelfde hoogte en breedte heeft als de invoerafbeelding van de U-Net. Zo kun je semantische maskers op pixelniveau verkrijgen.
Deze oefening maakt deel uit van de cursus
Deep Learning voor afbeeldingen met PyTorch
Oefeninstructies
- Definieer het laatste decoderblok en gebruik
torch.cat()om de skip-verbinding te maken.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(self.pool(x1))
x3 = self.enc3(self.pool(x2))
x4 = self.enc4(self.pool(x3))
x = self.upconv3(x4)
x = torch.cat([x, x3], dim=1)
x = self.dec1(x)
x = self.upconv2(x)
x = torch.cat([x, x2], dim=1)
x = self.dec2(x)
# Define the last decoder block with skip connections
x = ____
x = ____
x = ____
return self.out(x)