CommencerCommencez gratuitement

Ajout de méthodes à la classe MultiHeadAttention

Dans cet exercice, vous allez compléter le reste de la classe MultiHeadAttention en partant de zéro en définissant quatre méthodes :

  • .split_heads() : découper et transformer les embeddings d’entrée entre les têtes d’attention
  • .compute_attention() : calculer l’attention par produit scalaire mis à l’échelle, puis la multiplier par la matrice des valeurs
  • .combine_heads() : retransformer les poids d’attention pour retrouver la même forme que les embeddings d’entrée, x
  • .forward() : appeler les autres méthodes pour faire passer les embeddings d’entrée à travers chaque étape

torch.nn a été importé sous le nom nn, torch.nn.functional est disponible sous F, et torch est également disponible.

Cet exercice fait partie du cours

<cours>Modèles Transformer avec PyTorch</cours>
Voir le cours

Exercice interactif pratique

Essayez cet exercice en complétant ce code d’exemple.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        seq_length = x.size(1)
        # Split the input embeddings and permute
        x = x.____
        return x.permute(0, 2, 1, 3)
Modifier et exécuter le code