CommencerCommencer gratuitement

Ajout de méthodes à la classe MultiHeadAttention

Dans cet exercice, vous allez compléter le reste de la classe MultiHeadAttention en partant de zéro en définissant quatre méthodes :

  • .split_heads() : découper et transformer les embeddings d’entrée entre les têtes d’attention
  • .compute_attention() : calculer l’attention par produit scalaire mis à l’échelle, puis la multiplier par la matrice des valeurs
  • .combine_heads() : retransformer les poids d’attention pour retrouver la même forme que les embeddings d’entrée, x
  • .forward() : appeler les autres méthodes pour faire passer les embeddings d’entrée à travers chaque étape

torch.nn a été importé sous le nom nn, torch.nn.functional est disponible sous F, et torch est également disponible.

Cet exercice fait partie du cours

Modèles Transformer avec PyTorch

Afficher le cours

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        seq_length = x.size(1)
        # Split the input embeddings and permute
        x = x.____
        return x.permute(0, 2, 1, 3)
Modifier et exécuter le code