1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorch で学ぶテキストの Deep Learning

Connected

演習

アテンション付き RNN モデルの学習とテスト

PyBooks では、以前にアテンション機構なしで単語予測用の RNN モデルを構築していました。この初期モデル(rnn_model)はすでに学習済みで、インスタンスが読み込まれています。あなたのタスクは、新しい RNNWithAttentionModel を学習し、以前の rnn_model の予測と比較することです。

次のものがあらかじめ読み込まれています:

  • inputs: テンソルとしての入力系列のリスト
  • targets: 各入力系列に対応する目標単語を含むテンソル
  • optimizer: Adam オプティマイザ関数
  • criterion: CrossEntropyLoss 関数
  • pad_sequences: バッチ化のために入力系列にパディングを行う関数
  • attention_model: 前の演習で定義したモデルクラス
  • rnn_model: PyBooks のチームが作成した学習済み RNN モデル

指示

100 XP
  • テストデータで評価する前に、RNN モデルを評価モードに設定します。
  • 適切な入力を RNN モデルに渡して、RNN の出力を取得します。
  • RNN 出力から、予測スコアが最も高い単語を取り出します。
  • 同様に、アテンションモデルでも、アテンション出力から予測スコアが最も高い単語を取り出します。