1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorchで学ぶTransformerモデル

Connected

演習

マルチヘッド・アテンションの実装

独自の MultiHeadAttention クラスを実装する前に、このクラスを実際に使って、クエリ・キー・バリューの各行列がどのように変換されるかを確認します。これらの行列は、学習された重みによる線形変換で入力埋め込みを射影して生成されることを思い出してください。

query、key、value の各行列はすでに用意されており、MultiHeadAttention も定義済みです。

指示

100 XP
  • アテンションヘッドを 8 個、入力埋め込みの次元を 512 としてパラメータを定義します。
  • 定義したパラメータを使って MultiHeadAttention クラスのインスタンスを作成します。
  • query、key、value の各行列を multihead_attn メカニズムに通します。