1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorchによる中級ディープラーニング

Connected

演習

マルチ出力モデルの学習

複数の出力を持つモデルを学習するときは、損失関数を正しく定義することが重要です。

この例では、モデルはアルファベットと文字の2つの予測を出力します。各出力には対応する正解ラベルがあり、それぞれについて別々の損失を計算できます。1つはアルファベットの誤分類に対する損失、もう1つは文字の誤分類に対する損失です。どちらもマルチクラス分類タスクに該当するため、各回でCross-Entropy損失を適用できます。

ただし、勾配降下法で最適化できるのは1つの損失関数のみです。そのため、総損失はアルファベット損失と文字損失の合計として定義します。

指示

100 XP
  • アルファベット分類の損失を計算し、loss_alpha に代入します。
  • 文字分類の損失を計算し、loss_char に代入します。
  • 2つの部分損失の合計を総損失として計算し、loss に代入します。