IniziaInizia gratis

Aggiungere metodi alla classe MultiHeadAttention

In questo esercizio costruirai da zero il resto della classe MultiHeadAttention definendo quattro metodi:

  • .split_heads(): suddivide e trasforma gli embedding di input tra le teste di attention
  • .compute_attention(): calcola i pesi della scaled dot-product attention moltiplicati per la matrice dei values
  • .combine_heads(): riconverte i pesi dell'attenzione alla stessa forma degli embedding di input, x
  • .forward(): richiama gli altri metodi per far passare gli embedding di input attraverso ogni fase

torch.nn è stato importato come nn, torch.nn.functional è disponibile come F, e torch è anch'esso disponibile.

Questo esercizio fa parte del corso

Modelli Transformer con PyTorch

Visualizza il corso

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        seq_length = x.size(1)
        # Split the input embeddings and permute
        x = x.____
        return x.permute(0, 2, 1, 3)
Modifica ed esegui il codice