1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorchによる中級ディープラーニング

Connected

演習

画像分類器のトレーニングループ

いよいよ画像分類器を学習させます! 先ほど定義した Net を使って、7種類の雲のタイプを見分けられるようにトレーニングします。

損失関数と最適化手法を定義するには、torch.nn と torch.optim の関数をそれぞれ nn と optim としてインポート済みなので利用します。トレーニングループ自体を変更する必要はありません。これまでに書いたものと同じで、学習中に損失を表示するためのロジックが少し追加されているだけです。

指示

100 XP
  • num_classes を 7 に設定して Net クラスからモデルを作成し、net に代入します。
  • 損失関数をクロスエントロピー損失として定義し、criterion に代入します。
  • 最適化手法を Adam として定義し、モデルのパラメータと学習率 0.001 を渡して optimizer に代入します。
  • dataloader_train の学習用 images と labels を反復処理して、トレーニング用の for ループを開始します。