1. 学ぶ
  2. /
  3. コース
  4. /
  5. PyTorchで学ぶTransformerモデル

Connected

演習

MultiHeadAttentionClass の作成を始める

トークン埋め込みと位置埋め込みを作成するクラスを定義できたので、次はマルチヘッドアテンションを実行するクラスを定義します。まず、アテンション計算に使うパラメータと、入力埋め込みを query・key・value 行列に変換するための線形層、そして結合したアテンション重みを埋め込みに投影し直すための出力用の線形層を用意します。

torch.nn は nn としてインポート済みです。

指示

100 XP
  • 各アテンションヘッドが処理する埋め込み次元 head_dim を計算します。
  • 入力用の3つの層(query、key、value)と出力用の1つの層を定義し、入力層からは bias パラメータを外します。