MulaiMulai sekarang secara gratis

Membuat positional encoding

Membuat embedding untuk token adalah awal yang baik, tetapi embedding tersebut masih belum memuat informasi tentang posisi tiap token dalam urutan. Untuk mengatasinya, arsitektur transformer menggunakan positional encoding, yang menyandikan informasi posisi dari setiap token ke dalam embedding.

Anda akan membuat kelas PositionalEncoding dengan parameter berikut:

  • d_model: dimensi embedding masukan
  • max_seq_length: panjang urutan maksimum (atau panjang urutan jika setiap urutan memiliki panjang yang sama)

Latihan ini adalah bagian dari kursus

Model Transformer dengan PyTorch

Lihat Kursus

Petunjuk latihan

  • Buat matriks nol berukuran max_seq_length kali d_model.
  • Lakukan perhitungan sine dan cosine pada position * div_term untuk membuat nilai embedding posisi genap dan ganjil.
  • Pastikan pe bukan parameter yang dapat dipelajari selama pelatihan.
  • Tambahkan embedding posisi yang telah ditransformasikan ke embedding token masukan, x.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        # Create a matrix of zeros of dimensions max_seq_length by d_model
        pe = ____
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        # Perform the sine and cosine calculations
        pe[:, 0::2] = torch.____(position * div_term)
        pe[:, 1::2] = torch.____(position * div_term)
        # Ensure pe isn't a learnable parameter during training
        self.____('____', pe.unsqueeze(0))
        
    def forward(self, x):
        # Add the positional embeddings to the token embeddings
        return ____ + ____[:, :x.size(1)]

pos_encoding_layer = PositionalEncoding(d_model=512, max_seq_length=4)
output = pos_encoding_layer(token_embeddings)
print(output.shape)
print(output[0][0][:10])
Edit dan Jalankan Kode