IniziaInizia gratis

Creare un modello transformer

A PyBooks, il motore di raccomandazione su cui stai lavorando ha bisogno di capacità più raffinate per comprendere il sentiment delle recensioni degli utenti. Ritieni che l’uso dei transformer, un’architettura allo stato dell’arte, possa aiutare a raggiungere questo obiettivo. Decidi quindi di costruire un modello transformer che sappia codificare il sentiment nelle recensioni per dare il via al progetto.

I seguenti pacchetti sono già stati importati per te: torch, nn, optim.

I dati di input contengono frasi come: "I love this product", "This is terrible", "Could be better" … e le rispettive etichette binarie di sentiment come: 1, 0, 0, ...

I dati di input sono suddivisi e convertiti in embedding nelle seguenti variabili: train_sentences, train_labels, test_sentences, test_labels, token_embeddings

Questo esercizio fa parte del corso

Deep Learning per il testo con PyTorch

Visualizza il corso

Istruzioni dell'esercizio

  • Inizializza il transformer encoder.
  • Definisci il livello fully connected in base al numero di classi di sentiment.
  • Nel metodo forward, fai passare l’input attraverso il transformer encoder seguito dal livello lineare.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

class TransformerEncoder(nn.Module):
    def __init__(self, embed_size, heads, num_layers, dropout):
        super(TransformerEncoder, self).__init__()
        # Initialize the encoder 
        self.encoder = nn.____(
            nn.____(d_model=embed_size, nhead=heads),
            num_layers=num_layers)
        # Define the fully connected layer
        self.fc = nn.Linear(embed_size, ____)

    def forward(self, x):
        # Pass the input through the transformer encoder 
        x = self.____(x)
        x = x.mean(dim=1) 
        return self.fc(x)

model = TransformerEncoder(embed_size=512, heads=8, num_layers=3, dropout=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
Modifica ed esegui il codice