BaşlayınÜcretsiz Başlayın

Konumsal kodlamalar oluşturma

Token’ları gömmek iyi bir başlangıç, ancak bu gömmelerde hâlâ dizideki her token’ın konumu hakkında bilgi yok. Bunu gidermek için transformer mimarisi konumsal kodlamalardan yararlanır. Bu sayede her token’daki konumsal bilgi gömmelere işlenir.

Aşağıdaki parametrelere sahip bir PositionalEncoding sınıfı oluşturacaksın:

  • d_model: girdi gömmelerinin boyutu
  • max_seq_length: en büyük dizi uzunluğu (veya her dizi aynı uzunluktaysa dizi uzunluğu)

Bu egzersiz

PyTorch ile Transformer Modelleri

kursunun bir parçasıdır
Kursu Görüntüle

Egzersiz talimatları

  • Boyutları max_seq_length çarpı d_model olan sıfırlardan bir matris oluştur.
  • Çift ve tek konumsal gömme değerlerini oluşturmak için position * div_term üzerinde sinüs ve kosinüs hesaplarını yap.
  • Eğitim sırasında pe’nin öğrenilebilir bir parametre olmadığından emin ol.
  • Dönüştürülmüş konumsal gömmeleri, girdi token gömmelerine (x) ekle.

Uygulamalı interaktif egzersiz

Bu örnek kodu tamamlayarak bu egzersizi bitirin.

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        # Create a matrix of zeros of dimensions max_seq_length by d_model
        pe = ____
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        # Perform the sine and cosine calculations
        pe[:, 0::2] = torch.____(position * div_term)
        pe[:, 1::2] = torch.____(position * div_term)
        # Ensure pe isn't a learnable parameter during training
        self.____('____', pe.unsqueeze(0))
        
    def forward(self, x):
        # Add the positional embeddings to the token embeddings
        return ____ + ____[:, :x.size(1)]

pos_encoding_layer = PositionalEncoding(d_model=512, max_seq_length=4)
output = pos_encoding_layer(token_embeddings)
print(output.shape)
print(output[0][0][:10])
Kodu Düzenle ve Çalıştır