1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorch で学ぶ画像向け Deep Learning

Connected

演習

Generator の損失

GAN を学習させる前に、generator と discriminator の両方に対する損失関数を定義する必要があります。ここでは前者から始めます。

generator の役割は、discriminator が本物と判定してしまうような偽画像を生成することでした。そのため、生成した画像が discriminator によって偽物(ラベル 0)と分類された場合、generator は損失を受けます。

generator の損失を計算する gen_loss() 関数を定義してください。引数は4つです。

  • gen:generator モデル
  • disc:discriminator モデル
  • num_images:バッチ内の画像数
  • z_dim:入力となるランダムノイズの次元数

指示

100 XP
  • 形状が num_images × z_dim のランダムノイズを生成し、noise に代入します。
  • generator を使って noise から偽画像を生成し、fake に代入します。
  • 生成した偽画像に対する discriminator の予測を取得します。
  • discriminator の予測と、同じ形状の1からなるテンソルを引数にして criterion を呼び出し、generator の損失を計算します。