1. Lära sig
  2. /
  3. Courses
  4. /
  5. PyTorchで学ぶTransformerモデル

Connected

exercise

MultiHeadAttention クラスにメソッドを追加する

この演習では、MultiHeadAttention クラスの残りを一から実装し、次の4つのメソッドを定義します。

  • .split_heads(): 入力埋め込みをアテンションヘッド間に分割・変換します
  • .compute_attention(): スケールド・ドット積により計算したアテンション重みを values 行列に掛け合わせます
  • .combine_heads(): アテンション重みを入力埋め込み x と同じ形状に戻します
  • .forward(): これらのメソッドを呼び出し、入力埋め込みを各処理に通します

torch.nn は nn としてインポート済みで、torch.nn.functional は F として、torch も利用可能です。

Instruktioner 1 / 4

undefined XP
    1
    2
    3
    4
  • 入力埋め込み x を、(batch_size, seq_length, self.num_heads, self.head_dim) にリシェイプしてアテンションヘッド間に分割します。