1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorch で学ぶテキストの Deep Learning

Connected

演習

テキスト用の RNN モデルを構築する

PyBooks のデータアナリストとして、顧客とのやり取り、時系列データ、テキスト文書などの時系列・系列データに頻繁に出会います。RNN はこのようなデータの分析とインサイト抽出に有効です。この演習では、すでに前処理とエンコードが済んだ Newsgroup データセットを扱います。データセットには複数カテゴリの記事が含まれています。あなたのタスクは、RNN を用いて以下の3カテゴリに記事を分類することです。

rec.autos、sci.med、comp.graphics

次のモジュールはあらかじめ読み込まれています: torch、nn、optim。

さらに、パラメータ input_size、hidden_size (32)、num_layers (2)、num_classes も事前に用意されています。

この演習および以降の演習では、sklearn の fetch_20newsgroups データセットを使用します。

指示

100 XP
  • RNN クラスに RNN レイヤーと全結合の Linear レイヤーを追加して完成させます。
  • モデルを初期化します。
  • 勾配をゼロにしながら、RNN モデルを 10 エポック学習させます。