Starting the MultiHeadAttention class
Now that you've defined classes for creating token embeddings and positional embeddings, it's time to define a class for performing multi-head attention. To start, set up the parameters used for the attention calculation, along with the linear layers that transform the input embeddings into query, key, and value matrices, plus one layer that projects the combined attention outputs back into embeddings.
torch.nn has been imported as nn.
This exercise is part of the course Transformer Models with PyTorch.
Exercise instructions
- Calculate the embedding dimensions each attention head will process, head_dim.
- Define the three input layers (for query, key, and value) and one output layer, disabling the bias parameter (bias=False) in the input layers.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # Calculate the dimensions each head will process
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = ____

        # Define the three input layers and one output layer
        self.query_linear = nn.Linear(____, ____, bias=False)
        self.key_linear = nn.Linear(____)
        self.value_linear = ____
        self.output_linear = ____
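For reference, below is a minimal sketch of how the blanks could be completed. It assumes the standard choice that each projection maps d_model inputs to d_model outputs, so each head processes head_dim = d_model // num_heads dimensions; the d_model and num_heads values in the usage lines at the end are illustrative, not taken from the exercise.

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # Each head processes an equal slice of the embedding dimensions
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        # Three bias-free input projections (query, key, value) and one
        # output projection back to the embedding dimension
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

# Illustrative usage: with d_model=512 and num_heads=8,
# each head processes head_dim = 512 // 8 = 64 dimensions.
attention = MultiHeadAttention(d_model=512, num_heads=8)
print(attention.head_dim)  # 64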