Membangun U-Net: metode forward
Dengan layer encoder dan decoder telah didefinisikan, sekarang Anda dapat mengimplementasikan metode forward() dari U-Net. Masukan telah lebih dulu diteruskan melalui encoder untuk Anda. Namun, Anda perlu mendefinisikan blok decoder terakhir.
Tujuan decoder adalah melakukan upsampling pada peta fitur sehingga keluarannya memiliki tinggi dan lebar yang sama dengan citra masukan U-Net. Hal ini memungkinkan Anda memperoleh mask semantik pada tingkat piksel.
Latihan ini merupakan bagian dari kursus
Deep Learning untuk Gambar dengan PyTorch
Instruksi latihan
- Definisikan blok decoder terakhir, gunakan
torch.cat()untuk membentuk skip connection.
Latihan interaktif langsung praktik
Cobalah latihan ini dengan melengkapi kode contoh ini.
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(self.pool(x1))
x3 = self.enc3(self.pool(x2))
x4 = self.enc4(self.pool(x3))
x = self.upconv3(x4)
x = torch.cat([x, x3], dim=1)
x = self.dec1(x)
x = self.upconv2(x)
x = torch.cat([x, x2], dim=1)
x = self.dec2(x)
# Define the last decoder block with skip connections
x = ____
x = ____
x = ____
return self.out(x)