Aan de slagGa gratis aan de slag

Een U-Net bouwen: forward-methode

Nu de encoder- en decoderlagen gedefinieerd zijn, kun je de forward()-methode van de U-Net implementeren. De inputs zijn al door de encoder voor je gehaald. Je moet echter nog het laatste decoderblok definiëren.

Het doel van de decoder is om de feature maps te upsamplen zodat de output dezelfde hoogte en breedte heeft als de invoerafbeelding van de U-Net. Zo kun je semantische maskers op pixelniveau verkrijgen.

Deze oefening maakt deel uit van de cursus

Deep Learning voor afbeeldingen met PyTorch

Cursus bekijken

Oefeninstructies

  • Definieer het laatste decoderblok en gebruik torch.cat() om de skip-verbinding te maken.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

def forward(self, x):
    x1 = self.enc1(x)
    x2 = self.enc2(self.pool(x1))
    x3 = self.enc3(self.pool(x2))
    x4 = self.enc4(self.pool(x3))

    x = self.upconv3(x4)
    x = torch.cat([x, x3], dim=1)
    x = self.dec1(x)

    x = self.upconv2(x)
    x = torch.cat([x, x2], dim=1)
    x = self.dec2(x)

    # Define the last decoder block with skip connections
    x = ____
    x = ____
    x = ____

    return self.out(x)
Code bewerken en uitvoeren