
Creating positional encodings

Embedding the tokens is a good start, but these embeddings still lack information about each token's position in the sequence. To remedy this, the transformer architecture uses positional encodings, which encode each token's position directly into its embedding.
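For reference, the standard sinusoidal encodings from the original transformer paper ("Attention Is All You Need", Vaswani et al., 2017), for position pos and embedding-dimension pair index i, are:

PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

Each dimension pair gets its own frequency, so every position receives a unique pattern of values.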

You'll create a PositionalEncoding class with the following parameters:

  • d_model: the dimensionality of the input embeddings
  • max_seq_length: the maximum sequence length (or the sequence length if each sequence is the same length)

This exercise is part of the course

Transformer Models with PyTorch


Exercise instructions

  • Create a matrix of zeros of dimensions max_seq_length by d_model.
  • Perform the sine and cosine calculations on position * div_term to create the even and odd positional embedding values.
  • Ensure pe isn't a learnable parameter during training (see the register_buffer sketch after this list).
  • Add the transformed positional embeddings to the input token embeddings, x.
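
One step worth unpacking before you start: register_buffer stores a tensor on the module so it is saved in the state_dict and moved along with .to(device), but it is never updated by the optimizer. A minimal sketch (the Demo module here is just for illustration):

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Stored on the module and saved in state_dict, but excluded from parameters()
        self.register_buffer('const', torch.ones(3))

demo = Demo()
print(list(demo.parameters()))   # [] — buffers aren't learnable parameters
print(demo.state_dict().keys())  # odict_keys(['const'])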

Hands-on interactive exercise

Try this exercise by completing the sample code below.

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        # Create a matrix of zeros of dimensions max_seq_length by d_model
        pe = ____
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # div_term equals 10000^(-2i/d_model) for each even index 2i,
        # computed via exp/log for numerical stability
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        # Perform the sine and cosine calculations
        pe[:, 0::2] = torch.____(position * div_term)
        pe[:, 1::2] = torch.____(position * div_term)
        # Ensure pe isn't a learnable parameter during training
        self.____('____', pe.unsqueeze(0))
        
    def forward(self, x):
        # Add the positional embeddings to the token embeddings
        return ____ + ____[:, :x.size(1)]

# token_embeddings (batch x seq_len x d_model) is assumed to be defined in an earlier step
pos_encoding_layer = PositionalEncoding(d_model=512, max_seq_length=4)
output = pos_encoding_layer(token_embeddings)
print(output.shape)
print(output[0][0][:10])
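
If you get stuck, below is one possible completion of the blanks, using torch.zeros, torch.sin, torch.cos, and register_buffer as the instructions suggest. Because token_embeddings is created in an earlier exercise, a random tensor of matching shape stands in for it here:

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        # Matrix of zeros: one row per position, one column per embedding dimension
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        # Sine on even indices, cosine on odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # register_buffer keeps pe on the module without making it learnable
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Add the positional encodings to the token embeddings
        return x + self.pe[:, :x.size(1)]

token_embeddings = torch.randn(2, 4, 512)  # stand-in for the embeddings from the previous step
pos_encoding_layer = PositionalEncoding(d_model=512, max_seq_length=4)
output = pos_encoding_layer(token_embeddings)
print(output.shape)  # torch.Size([2, 4, 512])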