1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorch で学ぶ画像向け Deep Learning

Connected

演習

Generator

GAN のジェネレーターは、ランダムなノイズベクトルを入力として受け取り、生成画像を出力します。アーキテクチャを再利用しやすくするため、入力と出力の形状をモデルの引数として受け取るようにします。こうすることで、異なるサイズの入力ノイズやさまざまな形状の画像に対して同じモデルを使えます。

torch.nn はすでに nn としてインポートされています。さらに、全結合層・バッチ正規化・ReLU 活性化から成るブロックを返すカスタム関数 gen_block() も利用できます。これをジェネレーターの基本ブロックとして使います。

def gen_block(in_dim, out_dim):
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm1d(out_dim),
        nn.ReLU(inplace=True)
    )

指示

100 XP
  • self.generator を順伝播型のモデルとして定義します。
  • 最後の gen_block の後に、適切な入力サイズと out_dim を出力サイズにもつ全結合層を追加します。
  • その全結合層の後に Sigmoid 活性化を追加します。
  • forward() メソッドでは、モデルの入力を self.generator に通します。