1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorch で学ぶテキストの Deep Learning

Connected

演習

GANモデルの学習

PyBooks のチームは、Generative Adversarial Network (GAN) を使ったテキスト生成に向けて順調に進んでいます。ジェネレーターとディスクリミネーターのネットワーク定義は完了しました。次は、それらを学習させる段階です。最後のステップとして、偽データを生成し、実データと比較して GAN の学習状況を確認します。入力にはテンソルを用い、出力は入力テンソルに似た形になるようにします。これにより、PyBooks のチームは、テキストデータと同じ関係性をもつ特徴量を備えた合成データをテキスト分析に活用できます。

ジェネレーターとディスクリミネーターはそれぞれ generator と discriminator に初期化・保存されています。

この演習では次の変数が初期化されています。

  • seq_length = 5: 各合成データ系列の長さ
  • num_sequences = 100: 生成する系列の総数
  • num_epochs = 50: データセットを何周するか
  • print_every = 10: 出力表示の頻度(10エポックごとに結果を表示)

指示1 / 3

undefined XP
    1
    2
    3
  • 二値分類用の損失関数と Adam オプティマイザを定義します。
  • real_data を unsqueeze してディスクリミネーターを学習し、勾配の再計算を防ぎます。