U-Net bauen: forward-Methode
Mit den definierten Encoder- und Decoder-Schichten kannst du jetzt die forward()-Methode des U-Net implementieren. Die Eingaben wurden bereits für dich durch den Encoder geschickt. Du musst jedoch den letzten Decoder-Block definieren.
Ziel des Decoders ist es, die Feature-Maps hochzusampeln, sodass seine Ausgabe dieselbe Höhe und Breite hat wie das Eingabebild des U-Net. So erhältst du semantische Masken auf Pixel-Ebene.
Diese Übung ist Teil des Kurses
Deep Learning für Bilder mit PyTorch
Anleitung zur Übung
- Definiere den letzten Decoder-Block und verwende
torch.cat(), um die Skip-Connection zu bilden.
Interaktive Übung
Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(self.pool(x1))
x3 = self.enc3(self.pool(x2))
x4 = self.enc4(self.pool(x3))
x = self.upconv3(x4)
x = torch.cat([x, x3], dim=1)
x = self.dec1(x)
x = self.upconv2(x)
x = torch.cat([x, x2], dim=1)
x = self.dec2(x)
# Define the last decoder block with skip connections
x = ____
x = ____
x = ____
return self.out(x)