Building an RNN model for text
As a data analyst at PyBooks, you often encounter datasets that contain sequential information, such as customer interactions, time series data, or text documents. Recurrent neural networks (RNNs) read such data one step at a time while carrying a hidden state forward, which lets them pick up patterns that depend on order. In this exercise, you will work with the Newsgroup dataset, which has already been processed and encoded for you. It contains articles from several categories, and your task is to train an RNN that classifies them into three: rec.autos, sci.med, and comp.graphics.
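nn.CrossEntropyLoss, used in the sample code, expects each target to be an integer class index in the range 0 to num_classes - 1. The following is a minimal sketch of turning the three category names into such indices. The article labels are hypothetical. The preprocessed labels in this exercise may already be integers and may use a different ordering.

```python
import torch

# The three target categories from this exercise
categories = ["rec.autos", "sci.med", "comp.graphics"]
# Assign each category an integer class index: 0, 1, 2
label_to_index = {name: i for i, name in enumerate(categories)}

# Hypothetical labels for four articles
article_labels = ["sci.med", "rec.autos", "comp.graphics", "sci.med"]
y = torch.tensor([label_to_index[name] for name in article_labels], dtype=torch.long)
print(y)  # tensor([1, 0, 2, 1])
```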
The following have been loaded for you: torch, nn, and optim.
Additionally, the parameters input_size, hidden_size (32), num_layers (2), and num_classes have been preloaded for you.
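These parameters fix the tensor shapes inside the model. The sketch below shows those shapes for an nn.RNN built with batch_first=True. It uses the exercise's hidden_size of 32 and num_layers of 2. input_size, the batch size, and the sequence length are hypothetical values chosen for illustration.

```python
import torch
import torch.nn as nn

input_size = 10        # hypothetical; the exercise preloads its own value
hidden_size = 32       # from the exercise
num_layers = 2         # from the exercise
batch, seq_len = 4, 7  # hypothetical batch of 4 sequences, 7 steps each

rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
x = torch.randn(batch, seq_len, input_size)        # batch-first input
h0 = torch.zeros(num_layers, batch, hidden_size)   # h0 is NOT batch-first

out, h_n = rnn(x, h0)
print(out.shape)  # torch.Size([4, 7, 32]): top layer's output at every step
print(h_n.shape)  # torch.Size([2, 4, 32]): final hidden state of each layer
# For a unidirectional RNN, the last step of out is the top layer's final state
print(torch.allclose(out[:, -1, :], h_n[-1]))  # True
```

This is why `out[:, -1, :]` in the sample code yields one hidden_size vector per article, which the linear layer then maps to num_classes scores.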
This and the following exercises use the 20 Newsgroups dataset, loaded with sklearn's fetch_20newsgroups function.
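Because the RNN layer is created with batch_first=True, the encoded articles must form a float tensor of shape (batch, seq_len, input_size). The exercise does not show how the articles were encoded, so the following is only a hypothetical sketch. It splits toy sentences on whitespace, pads them to a common length with a reserved index 0, and one-hot encodes each token, so that input_size equals the vocabulary size.

```python
import torch
import torch.nn.functional as F

# Toy documents standing in for newsgroup articles (hypothetical data)
docs = ["engine oil change", "doctor visit", "render 3d graphics card"]
tokens = [d.split() for d in docs]

# Build a vocabulary; index 0 is reserved for padding
vocab = {"<pad>": 0}
for toks in tokens:
    for t in toks:
        vocab.setdefault(t, len(vocab))

# Pad every document to the longest length with the padding index
seq_len = max(len(toks) for toks in tokens)
ids = torch.tensor(
    [[vocab[t] for t in toks] + [0] * (seq_len - len(toks)) for toks in tokens]
)

# One-hot encode to (batch, seq_len, vocab_size) and cast to float for nn.RNN
X_seq = F.one_hot(ids, num_classes=len(vocab)).float()
print(X_seq.shape)  # torch.Size([3, 4, 10])
```

Real pipelines usually replace one-hot vectors with an embedding layer or vectorizer features, but the required input shape is the same.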
This exercise is part of the course Deep Learning for Text with PyTorch.
Exercise instructions
- Complete the RNN class with an RNN layer and a fully connected linear layer.
- Initialize the model.
- Train the RNN model for ten epochs, zeroing the gradients at the start of each epoch.
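Zeroing the gradients matters because PyTorch adds each new gradient to whatever is already stored in a parameter's .grad. The following minimal sketch uses a single scalar parameter and a plain SGD optimizer, both chosen only for illustration, to show this accumulation and how optimizer.zero_grad() clears it.

```python
import torch

# A single scalar parameter and an optimizer that manages it
w = torch.tensor(2.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Two backward passes without zeroing: each adds d(3w)/dw = 3 to w.grad
for _ in range(2):
    loss = 3 * w
    loss.backward()
accumulated = w.grad.clone()
print(accumulated)  # tensor(6.)

# zero_grad() clears the stored gradient before the next backward pass
optimizer.zero_grad()
loss = 3 * w
loss.backward()
print(w.grad)  # tensor(3.)
```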
Have a go at this exercise by completing this sample code.
# Complete the RNN class
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = ____.____(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = ____.____(hidden_size, num_classes)

    def forward(self, x):
        # Initial hidden state of shape (num_layers, batch, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.rnn(x, h0)
        # Keep only the output at the last time step: (batch, hidden_size)
        out = out[:, -1, :]
        out = self.fc(out)
        return out

# Initialize the model
rnn_model = ____(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn_model.parameters(), lr=0.01)

# Train the model for ten epochs and zero the gradients
# (X_train_seq and y_train_seq are assumed to be the preprocessed training tensors)
for epoch in ____:
    optimizer.____()
    outputs = ____(X_train_seq)
    loss = criterion(outputs, y_train_seq)
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')
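For reference, here is one possible completion of the sample code as a self-contained sketch. It assumes the blanks are filled with nn.RNN, nn.Linear, RNNModel, range(10), zero_grad, and rnn_model. The exercise's preloaded X_train_seq and y_train_seq are not shown, so random tensors with hypothetical sizes stand in for them, and input_size and num_classes are hypothetical values. Passing device=x.device to torch.zeros is a small addition that keeps h0 on the same device as the input.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Stacked RNN reading (batch, seq_len, input_size) inputs
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        # Maps the last hidden state to one score per class
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Initial hidden state (num_layers, batch, hidden_size) on the input's device
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.rnn(x, h0)
        out = out[:, -1, :]   # keep only the last time step
        return self.fc(out)   # raw logits, as nn.CrossEntropyLoss expects

# Hypothetical stand-ins for the preloaded values and tensors
input_size, hidden_size, num_layers, num_classes = 10, 32, 2, 3
X_train_seq = torch.randn(64, 5, input_size)
y_train_seq = torch.randint(0, num_classes, (64,))

rnn_model = RNNModel(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn_model.parameters(), lr=0.01)

losses = []
for epoch in range(10):
    optimizer.zero_grad()              # clear gradients from the previous epoch
    outputs = rnn_model(X_train_seq)   # full-batch forward pass
    loss = criterion(outputs, y_train_seq)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    print(f"Epoch: {epoch+1}, Loss: {loss.item():.4f}")
```

Like the sample code, this trains on the whole training set in a single batch each epoch. With a real dataset of this size you would typically iterate over mini-batches with a DataLoader instead.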