Adding methods to the MultiHeadAttention class
In this exercise, you'll build out the rest of the MultiHeadAttention
class by defining four methods:
- .split_heads(): split and transform the input embeddings across the attention heads.
- .compute_attention(): calculate the scaled dot-product attention weights and multiply them by the values matrix (see the sketch after this list).
- .combine_heads(): transform the attention outputs back into the same shape as the input embeddings, x.
- .forward(): call the other methods to pass the input embeddings through each of these steps.
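As a reference for the middle step, here is a minimal sketch of the scaled dot-product attention that .compute_attention() performs on tensors that have already been split into heads. The standalone function and argument names are illustrative only and are not part of the exercise's starter code; the imports are repeated to keep the sketch self-contained.

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value):
    # query, key, value: (batch_size, num_heads, seq_length, head_dim)
    head_dim = query.size(-1)
    # Scale the dot products by the square root of the head dimension
    scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)
    # Softmax over the key positions gives the attention weights
    attention_weights = F.softmax(scores, dim=-1)
    # Multiply the attention weights by the values matrix
    return torch.matmul(attention_weights, value)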
torch.nn has been imported as nn, torch.nn.functional is available as F, and torch is also available.
This exercise is part of the course Transformer Models with PyTorch.
Have a go at this exercise by completing this sample code.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        seq_length = x.size(1)
        # Split the input embeddings and permute
        x = x.____
        return x.permute(0, 2, 1, 3)
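For reference, one possible completion of the blank and of the remaining three methods is sketched below. This is a sketch under the assumptions stated here, not the course's official solution; in particular, the optional mask argument and the example sizes in the usage lines are assumptions added for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        seq_length = x.size(1)
        # Reshape (batch_size, seq_length, d_model) into separate heads
        x = x.view(batch_size, seq_length, self.num_heads, self.head_dim)
        # Move the heads dimension ahead of the sequence dimension
        return x.permute(0, 2, 1, 3)

    def compute_attention(self, query, key, value, mask=None):
        # Scaled dot-product attention weights multiplied by the values matrix
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Assumption: mask uses 0 to hide positions from the softmax
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def combine_heads(self, x, batch_size):
        # Undo split_heads: back to (batch_size, seq_length, d_model)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(batch_size, -1, self.d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Project the inputs and split them across the heads
        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)
        # Attend, recombine the heads, and apply the output projection
        attention_output = self.compute_attention(query, key, value, mask)
        return self.output_linear(self.combine_heads(attention_output, batch_size))

# Example usage with illustrative sizes
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)   # (batch_size, seq_length, d_model)
output = mha(x, x, x)         # self-attention; output shape: (2, 10, 512)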