Faltungsgenerator
Definiere einen Convolutional Generator nach den DCGAN-Richtlinien, die im letzten Video besprochen wurden.
torch.nn
wurde für dich schon mal als „ nn
” importiert. Außerdem gibt's eine eigene Funktion namens „ dc_gen_block()
“, die einen Block mit einer transponierten Faltung, Batch-Normalisierung und ReLU-Aktivierung zurückgibt. Diese Funktion ist eine wichtige Grundlage für den Aufbau des Faltungsgenerators. Du kannst dich unten mit der Definition von „ dc_gen_block()
” vertraut machen.
def dc_gen_block(in_dim, out_dim, kernel_size, stride):
return nn.Sequential(
nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride=stride),
nn.BatchNorm2d(out_dim),
nn.ReLU()
)
Diese Übung ist Teil des Kurses
Deep Learning für Bilder mit PyTorch
Anleitung zur Übung
- Füge den letzten Generatorblock hinzu und ordne die Größe der Feature-Maps „
256
“ zu. - Füge eine transponierte Faltung mit der Ausgabegröße „
3
“ hinzu. - Füge die tanh-Aktivierung hinzu.
Interaktive Übung
Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.
class DCGenerator(nn.Module):
def __init__(self, in_dim, kernel_size=4, stride=2):
super(DCGenerator, self).__init__()
self.in_dim = in_dim
self.gen = nn.Sequential(
dc_gen_block(in_dim, 1024, kernel_size, stride),
dc_gen_block(1024, 512, kernel_size, stride),
# Add last generator block
____,
# Add transposed convolution
____(____, ____, kernel_size, stride=stride),
# Add tanh activation
____
)
def forward(self, x):
x = x.view(len(x), self.in_dim, 1, 1)
return self.gen(x)