Building an LSTM model for text
At PyBooks, the team is always looking to improve the user experience with recent advances in technology. In line with this goal, they have assigned you a critical task: explore the potential of another powerful tool, the LSTM, which can capture more complex patterns in sequential data. You are working with the same Newsgroup dataset, and the objective is unchanged: classify news articles into three distinct categories: rec.autos, sci.med, and comp.graphics.
The following packages have been loaded for you: torch, nn, and optim.
This exercise is part of the course Deep Learning for Text with PyTorch.
Exercise instructions
- Set up an LSTM model by completing the LSTM and linear layers with the necessary parameters.
- Initialize the model with the necessary parameters.
- Train the LSTM model by resetting the gradients to zero and passing the input data X_train_seq through the model.
- Calculate the loss based on the predicted outputs and the true labels.
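To make the first instruction more concrete, the sketch below traces tensor shapes through an LSTM layer followed by a linear layer, in the same order the exercise uses them. All sizes here (batch of 4, sequence length 5, 8 input features, 16 hidden units, 2 layers) are hypothetical values chosen for illustration and are not taken from the exercise data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only to make the shapes easy to follow
batch_size, seq_len, input_size = 4, 5, 8
hidden_size, num_layers, num_classes = 16, 2, 3

# With batch_first=True the LSTM expects input of shape (batch, seq_len, input_size)
lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
# The linear layer maps the hidden state to one score per class
fc = nn.Linear(hidden_size, num_classes)

x = torch.randn(batch_size, seq_len, input_size)
# Initial hidden and cell states: (num_layers, batch, hidden_size)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)

out, (hn, cn) = lstm(x, (h0, c0))  # out: (batch, seq_len, hidden_size)
last_step = out[:, -1, :]          # keep only the final time step: (batch, hidden_size)
logits = fc(last_step)             # class scores: (batch, num_classes)

print(out.shape, last_step.shape, logits.shape)
```

Taking out[:, -1, :] keeps one vector per sequence, which is why the linear layer's input size has to match the LSTM's hidden size.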
Hands-on interactive exercise
Try this exercise by completing the sample code.
# Initialize the LSTM and the output layer with parameters
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(____, ____, ____, batch_first=True)
        self.fc = nn.Linear(____, ____)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]
        out = self.fc(out)
        return out

# Initialize model with required parameters
lstm_model = LSTMModel(____, ____, ____, ____)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(lstm_model.parameters(), lr=0.01)

# Train the model by passing the correct parameters and zeroing the gradient
for epoch in range(10):
    optimizer.____
    outputs = lstm_model(____)
    loss = criterion(____, y_train_seq)
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')
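One possible way to complete the sample is sketched below so that it runs on its own. The exercise's X_train_seq and y_train_seq are not shown on this page, so random tensors stand in for them here, and the sizes (20 input features, 32 hidden units, 2 layers, 30 samples) are assumptions made for illustration rather than the values the exercise expects. Only the three-class output follows from the task description.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Zero initial hidden and cell states, one per layer and sample
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]  # last time step only
        return self.fc(out)

# Assumed sizes; the real exercise data may differ
input_size, hidden_size, num_layers, num_classes = 20, 32, 2, 3

# Random stand-ins for the exercise's X_train_seq and y_train_seq
X_train_seq = torch.randn(30, 1, input_size)            # (samples, seq_len, features)
y_train_seq = torch.randint(0, num_classes, (30,))      # integer class labels

lstm_model = LSTMModel(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(lstm_model.parameters(), lr=0.01)

losses = []
for epoch in range(10):
    optimizer.zero_grad()                    # reset gradients from the previous step
    outputs = lstm_model(X_train_seq)        # forward pass
    loss = criterion(outputs, y_train_seq)   # compare predictions with true labels
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')
```

Because the stand-in labels are random, the printed loss values say nothing about performance on the Newsgroup data; the sketch only shows the order of the training steps.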