
Multi-head self-attention

1. Multi-head self-attention

Now that we've explored positional encoding, we'll look at attention mechanisms.

2. Multi-head attention in transformers

Attention mechanisms feature prominently in the transformer's encoder and decoder blocks.

3. Self-attention mechanism

Self-attention enables transformers to identify relationships between tokens and focus on the most relevant ones for the task. Given a sequence of token embeddings,

4. Self-attention mechanism

each embedding is projected into three matrices of equal dimensions - query (Q), key (K), and value (V) - using separate linear transformations with learned weights. The query indicates what each token is "looking for" in other tokens; the key represents the "content" of each token that other tokens might find relevant; and the value contains the actual content to be aggregated, weighted by the attention scores. Transforming each token's embedding into these three roles helps the model learn more nuanced token relationships.
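As a minimal sketch of these projections (the tensor sizes and variable names here are illustrative, not taken from the course code):

```python
import torch
import torch.nn as nn

# Illustrative sizes: batch of 1, sequence of 4 tokens, 8-dimensional embeddings
batch_size, seq_length, d_model = 1, 4, 8
embeddings = torch.randn(batch_size, seq_length, d_model)

# Separate linear transformations with learned weights for each role
query_linear = nn.Linear(d_model, d_model, bias=False)
key_linear = nn.Linear(d_model, d_model, bias=False)
value_linear = nn.Linear(d_model, d_model, bias=False)

Q = query_linear(embeddings)  # what each token is "looking for"
K = key_linear(embeddings)    # the "content" other tokens may match against
V = value_linear(embeddings)  # the content to be aggregated
```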

5. Self-attention mechanism

To obtain a matrix of attention scores between tokens, we compute the similarity between Q and K, typically using the dot-product metric.
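Continuing the sketch above, the score computation might look like this (dividing by the square root of the dimension is the common scaled dot-product variant, an assumption here rather than something stated on the slide):

```python
import math

# Dot product between queries and keys; transposing K aligns the dimensions.
# Scaling by sqrt(d_model) is a standard stabilizing choice (assumed).
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_model)
# scores has shape (batch_size, seq_length, seq_length):
# one similarity value for every pair of tokens
```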

6. Self-attention mechanism

Attention weights are calculated by applying softmax to the attention scores. These weights reflect the relevance the model assigns to each token in a sequence. In the sequence,

7. Self-attention mechanism

"orange is my favorite fruit," the tokens "favorite" and "fruit" receive the highest attention when processing "orange," as they directly influence its context and meaning. The model interprets "orange" as a favored fruit rather than a color or other meaning.

8. Self-attention mechanism

Finally, we multiply the value matrix, which holds the projected token embeddings, by the attention weights

9. Self-attention mechanism

to update the embeddings with the self-attention information. This self-attention mechanism uses one attention head with one set of Q, K, and V matrices.
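In the running sketch, this final step is a single matrix multiplication:

```python
# Each token's updated embedding is a weighted sum of the value vectors,
# mixing in information from the tokens it attends to
updated_embeddings = torch.matmul(attention_weights, V)
# shape: (batch_size, seq_length, d_model), same as the input embeddings
```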

10. Multi-head attention

In practice, embeddings are split across multiple attention heads so they can focus on different aspects of the sequence in parallel,

11. Multi-head attention

like relationships, sentiment, or subject.

12. Multi-head attention

Multi-head attention concatenates the attention-head outputs and linearly transforms them to match the input dimensions. The resulting embeddings capture token meaning, positional encoding, and contextual relationships.
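In tensor terms, the concatenation and final projection might look like this toy sketch (sizes assumed for illustration):

```python
import torch
import torch.nn as nn

batch_size, seq_length, num_heads, head_dim = 1, 4, 2, 4
d_model = num_heads * head_dim

# One output tensor per head, each covering head_dim dimensions
head_outputs = [torch.randn(batch_size, seq_length, head_dim)
                for _ in range(num_heads)]

# Concatenate along the embedding dimension, then project back to d_model
concatenated = torch.cat(head_outputs, dim=-1)  # (1, 4, 8)
output_linear = nn.Linear(d_model, d_model)
output = output_linear(concatenated)            # matches the input dimensions
```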

13. MultiHeadAttention class

Let's define a MultiHeadAttention class to put this into action, again based on nn.Module. In the __init__ method, num_heads is the number of attention heads, and head_dim is the number of embedding dimensions each head will process; d_model must be divisible by num_heads. Three linear layers are defined for the attention inputs - query, key, and value - and one for the final concatenated output. Using bias=False for the query, key, and value layers eliminates the bias term, reducing model parameters without impacting the model's ability to capture relationships.
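A sketch of the constructor consistent with this description (an approximation, not the verbatim course code):

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # d_model must be divisible by num_heads
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # bias=False drops the bias term, reducing parameters
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        # Final layer applied to the concatenated head outputs
        self.output_linear = nn.Linear(d_model, d_model)
```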

14. MultiHeadAttention class

We'll also define three helper methods for the different steps in the mechanism. split_heads splits the query, key, and value tensors between the heads, reshaping them to batch_size, seq_length, num_heads, head_dim and reordering so that each head processes its slice of the embedding independently. Inside compute_attention, torch.matmul calculates the dot product between the query and key matrices, which requires transposing the key matrix, and softmax turns these scores into attention weights inside each head. We also add a condition to allow for masking, which we'll discuss more in the next chapter. Finally, we return these attention weights multiplied by the value matrix. combine_heads transforms the attention outputs back into the original embedding shape.
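The helpers might be sketched as follows, continuing the class above (again an approximation of the described behavior):

```python
    def split_heads(self, x, batch_size):
        # (batch, seq, d_model) -> (batch, seq, num_heads, head_dim),
        # then move the head axis forward so each head attends independently
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def compute_attention(self, query, key, value, mask=None):
        # Dot product between queries and (transposed) keys
        scores = torch.matmul(query, key.transpose(-2, -1))
        if mask is not None:
            # Masked positions get -inf so softmax assigns them zero weight
            scores = scores.masked_fill(mask == 0, float("-inf"))
        # Attention weights inside each head
        attention_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def combine_heads(self, x, batch_size):
        # (batch, num_heads, seq, head_dim) -> (batch, seq, d_model)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(batch_size, -1, self.num_heads * self.head_dim)
```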

15. MultiHeadAttention class

In the forward method, we split the query, key, and value tensors across the heads and compute the attention outputs within each head. These outputs are then combined and passed through the output layer to obtain updated token embeddings projected back into the original dimensionality.
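A matching forward sketch completing the class, followed by a tiny self-attention usage example (sizes assumed):

```python
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Project the inputs and split them across the heads
        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)
        # Attention within each head, then recombine and project back
        attention_out = self.compute_attention(query, key, value, mask)
        return self.output_linear(self.combine_heads(attention_out, batch_size))

# Self-attention: query, key, and value all come from the same embeddings
mha = MultiHeadAttention(d_model=8, num_heads=2)
x = torch.randn(1, 4, 8)
updated = mha(x, x, x)  # shape (1, 4, 8)
```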

16. Let's practice!

Time to define your own attention mechanism!
