Membangun U-Net: metode forward
Dengan layer encoder dan decoder telah didefinisikan, sekarang Anda dapat mengimplementasikan metode forward() dari U-Net. Masukan telah lebih dulu diteruskan melalui encoder untuk Anda. Namun, Anda perlu mendefinisikan blok decoder terakhir.
Tujuan decoder adalah melakukan upsampling pada peta fitur sehingga keluarannya memiliki tinggi dan lebar yang sama dengan citra masukan U-Net. Hal ini memungkinkan Anda memperoleh mask semantik pada tingkat piksel.
Latihan ini adalah bagian dari kursus
Deep Learning untuk Gambar dengan PyTorch
Petunjuk latihan
- Definisikan blok decoder terakhir, gunakan
torch.cat()untuk membentuk skip connection.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(self.pool(x1))
x3 = self.enc3(self.pool(x2))
x4 = self.enc4(self.pool(x3))
x = self.upconv3(x4)
x = torch.cat([x, x3], dim=1)
x = self.dec1(x)
x = self.upconv2(x)
x = torch.cat([x, x2], dim=1)
x = self.dec2(x)
# Define the last decoder block with skip connections
x = ____
x = ____
x = ____
return self.out(x)