1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorch で学ぶ画像向け Deep Learning

Connected

演習

トレーニングループ

ついに、これまで定義してきたモデル構造と損失関数が活躍するときです。いよいよ学習を実行します。あなたの役割は、GAN のトレーニングループを実装して実行することです。注意: 実行時間を長くしないため、最初のバッチの後で break 文が入っています。

2 つのオプティマイザ disc_opt と gen_opt は、Adam() オプティマイザとして初期化済みです。先ほど定義した損失を計算する関数 gen_loss() と disc_loss() が利用できます。dataloader も用意されています。

次を思い出してください。

  • disc_loss() の引数は順に gen, disc, real, cur_batch_size, z_dim です。
  • gen_loss() の引数は順に gen, disc, cur_batch_size, z_dim です。

指示

100 XP
  • 生成器、識別器、実画像のサンプル、現在のバッチサイズ、ノイズサイズ 16 の順で disc_loss() に渡して識別器の損失を計算し、結果を d_loss に代入します。
  • d_loss を使って勾配を計算します。
  • 生成器、識別器、現在のバッチサイズ、ノイズサイズ 16 の順で gen_loss() に渡して生成器の損失を計算し、結果を g_loss に代入します。
  • g_loss を使って勾配を計算します。