Aan de slagBegin gratis

Een U-Net bouwen: forward-methode

Nu de encoder- en decoderlagen gedefinieerd zijn, kun je de forward()-methode van de U-Net implementeren. De inputs zijn al door de encoder voor je gehaald. Je moet echter nog het laatste decoderblok definiëren.

Het doel van de decoder is om de feature maps te upsamplen zodat de output dezelfde hoogte en breedte heeft als de invoerafbeelding van de U-Net. Zo kun je semantische maskers op pixelniveau verkrijgen.

Deze oefening maakt deel uit van de cursus

Deep Learning voor afbeeldingen met PyTorch

Bekijk cursus

Oefeninstructies

  • Definieer het laatste decoderblok en gebruik torch.cat() om de skip-verbinding te maken.

Interactieve oefening met praktijkervaring

Probeer deze oefening door deze voorbeeldcode aan te vullen.

def forward(self, x):
    x1 = self.enc1(x)
    x2 = self.enc2(self.pool(x1))
    x3 = self.enc3(self.pool(x2))
    x4 = self.enc4(self.pool(x3))

    x = self.upconv3(x4)
    x = torch.cat([x, x3], dim=1)
    x = self.dec1(x)

    x = self.upconv2(x)
    x = torch.cat([x, x2], dim=1)
    x = self.dec2(x)

    # Define the last decoder block with skip connections
    x = ____
    x = ____
    x = ____

    return self.out(x)
Code bewerken en uitvoeren