Costruire un modello RNN per il testo
Come data analyst in PyBooks, ti capita spesso di lavorare con insiemi di dati che contengono informazioni sequenziali, come interazioni con i clienti, serie temporali o documenti di testo. Le RNN possono analizzare efficacemente questo tipo di dati ed estrarne insight. In questo esercizio, esplorerai il dataset Newsgroup, già preprocessato e codificato per te. Questo insieme di dati comprende articoli di categorie diverse. Il tuo compito è applicare una RNN per classificare questi articoli in tre categorie:
rec.autos, sci.med e comp.graphics.
Sono già stati caricati per te: torch, nn, optim.
Inoltre, i parametri input_size, hidden_size (32), num_layers (2) e num_classes sono stati precaricati.
Questo e i prossimi esercizi usano il dataset fetch_20newsgroups da sklearn.
Questo esercizio fa parte del corso
Deep Learning per il testo con PyTorch
Istruzioni dell'esercizio
- Completa la classe RNN con un livello RNN e un livello Linear fully connected.
- Inizializza il modello.
- Addestra il modello RNN per dieci epoche azzerando i gradienti.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
# Complete the RNN class
class RNNModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(RNNModel, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.rnn = ____.____(input_size, hidden_size, num_layers, batch_first=True)
self.fc = ____.____(hidden_size, num_classes)
def forward(self, x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
out, _ = self.rnn(x, h0)
out = out[:, -1, :]
out = self.fc(out)
return out
# Initialize the model
rnn_model = ____(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn_model.parameters(), lr=0.01)
# Train the model for ten epochs and zero the gradients
for epoch in ____:
optimizer.____()
outputs = ____(X_train_seq)
loss = criterion(outputs, y_train_seq)
loss.backward()
optimizer.step()
print(f'Epoch: {epoch+1}, Loss: {loss.item()}')