Aggiungere metodi alla classe MultiHeadAttention
In questo esercizio costruirai da zero il resto della classe MultiHeadAttention definendo quattro metodi:
.split_heads(): suddivide e trasforma gli embedding di input tra le teste di attention.compute_attention(): calcola i pesi della scaled dot-product attention moltiplicati per la matrice dei values.combine_heads(): riconverte i pesi dell'attenzione alla stessa forma degli embedding di input,x.forward(): richiama gli altri metodi per far passare gli embedding di input attraverso ogni fase
torch.nn è stato importato come nn, torch.nn.functional è disponibile come F, e torch è anch'esso disponibile.
Questo esercizio fa parte del corso
Modelli Transformer con PyTorch
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.num_heads = num_heads
self.d_model = d_model
self.head_dim = d_model // num_heads
self.query_linear = nn.Linear(d_model, d_model, bias=False)
self.key_linear = nn.Linear(d_model, d_model, bias=False)
self.value_linear = nn.Linear(d_model, d_model, bias=False)
self.output_linear = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size):
seq_length = x.size(1)
# Split the input embeddings and permute
x = x.____
return x.permute(0, 2, 1, 3)