1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorchによる中級ディープラーニング

Connected

演習

LSTM ネットワーク

ご存じのとおり、プレーンな RNN セルは実務ではあまり使われません。長い系列をより適切に扱える、より一般的な選択肢が Long Short-Term Memory セル、つまり LSTM です。この演習では、LSTM ネットワークを自分で実装してみます!

先ほど作成した RNN ネットワークとの最も重要な実装上の違いは、LSTM では隠れ状態が1つではなく2つあることです。つまり、この追加の隠れ状態を初期化し、LSTM セルに渡す必要があります。

torch と torch.nn はすでにインポート済みです。さっそくコーディングを始めましょう!

指示

100 XP
  • .__init__() メソッドで、LSTM レイヤーを定義して self.lstm に代入します。
  • forward() メソッドで、最初の長期記憶の隠れ状態 c0 をゼロで初期化します。
  • forward() メソッドで、現在のタイムステップの入力と、2つの隠れ状態を含むタプルの3つを LSTM レイヤーに渡します。