1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorch で学ぶ画像向け Deep Learning

Connected

演習

識別器の損失

識別器の損失を定義しましょう。識別器の役割は、画像が本物か偽物かを分類することでした。したがって、生成器の出力を本物(ラベル 1)と誤って判定したり、本物の画像を偽物(ラベル 0)と誤判定した場合に、識別器は損失を負います。

識別器の損失を計算する disc_loss() 関数を定義してください。引数は5つです。

  • gen:生成器モデル
  • disc:識別器モデル
  • real:学習データからの本物画像のサンプル
  • num_images:バッチ内の画像枚数
  • z_dim:入力ノイズの次元数

指示

100 XP
  • 識別器を使って fake 画像を分類し、予測を disc_pred_fake に代入します。
  • 偽画像に対する損失成分は、識別器の偽画像に対する予測と、同じ形状のゼロのテンソルを使って criterion を呼び出して計算します。
  • 識別器を使って real 画像を分類し、予測を disc_pred_real に代入します。
  • 本物画像に対する損失成分は、識別器の本物画像に対する予測と、同じ形状の1のテンソルを使って criterion を呼び出して計算します。