1. Nauka
  2. /
  3. Kursy
  4. /
  5. Modele Transformer w PyTorch

Connected

ćwiczenie

Dodawanie metod do klasy MultiHeadAttention

W tym ćwiczeniu zbudujesz resztę klasy MultiHeadAttention od podstaw, definiując cztery metody:

  • .split_heads(): dzieli i przekształca wejściowe osadzenia (embeddings) między głowice atencji
  • .compute_attention(): oblicza skalowane iloczyny skalarnych wag atencji pomnożonych przez macierz wartości
  • .combine_heads(): przekształca wagi atencji z powrotem do kształtu zgodnego z wejściowymi osadzeniami, x
  • .forward(): wywołuje pozostałe metody, przepuszczając wejściowe osadzenia przez każdy etap przetwarzania

torch.nn jest zaimportowany jako nn, torch.nn.functional jest dostępny jako F, a torch jest również dostępny.

Instrukcje 1/4

undefined XP
    1
    2
    3
    4
  • Podziel wejściowe osadzenia x między głowice atencji, zmieniając ich kształt na (batch_size, seq_length, self.num_heads, self.head_dim).