Aan de slagGa gratis aan de slag

Positionele encodings maken

Tokens embedden is een goed begin, maar deze embeddings missen nog informatie over de positie van elke token in de reeks. Om dit op te lossen, maakt de transformerarchitectuur gebruik van positionele encodings. Daarmee codeer je positie-informatie van elke token in de embeddings.

Je maakt een PositionalEncoding-klasse met de volgende parameters:

  • d_model: de dimensionaliteit van de invoer-embeddings
  • max_seq_length: de maximale reekslengte (of de reekslengte als elke reeks even lang is)

Deze oefening maakt deel uit van de cursus

Transformermodels met PyTorch

Cursus bekijken

Oefeninstructies

  • Maak een nulmatrix met afmetingen max_seq_length bij d_model.
  • Voer de sinus- en cosinusberekeningen uit op position * div_term om de even en oneven positionele embedding-waarden te maken.
  • Zorg ervoor dat pe geen leerbare parameter is tijdens het trainen.
  • Tel de getransformeerde positionele embeddings op bij de token-embeddings van de invoer, x.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        # Create a matrix of zeros of dimensions max_seq_length by d_model
        pe = ____
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        # Perform the sine and cosine calculations
        pe[:, 0::2] = torch.____(position * div_term)
        pe[:, 1::2] = torch.____(position * div_term)
        # Ensure pe isn't a learnable parameter during training
        self.____('____', pe.unsqueeze(0))
        
    def forward(self, x):
        # Add the positional embeddings to the token embeddings
        return ____ + ____[:, :x.size(1)]

pos_encoding_layer = PositionalEncoding(d_model=512, max_seq_length=4)
output = pos_encoding_layer(token_embeddings)
print(output.shape)
print(output[0][0][:10])
Code bewerken en uitvoeren