Réseau GRU
À côté des LSTM, une autre variante populaire des réseaux de neurones récurrents est le Gated Recurrent Unit, ou GRU. Son intérêt réside dans sa simplicité : les cellules GRU nécessitent moins de calcul que les cellules LSTM tout en offrant souvent des performances comparables.
Le code fourni correspond à la définition du modèle RNN que vous avez écrit précédemment. Votre tâche est de l’adapter pour obtenir un réseau GRU à la place. torch et torch.nn as nn ont déjà été importés pour vous.
Cet exercice fait partie du cours
Deep learning intermédiaire avec PyTorch
Instructions
- Mettez à jour la définition du modèle RNN afin d’obtenir un réseau GRU ; affectez la couche GRU à
self.gru.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
class Net(nn.Module):
def __init__(self):
super().__init__()
# Define RNN layer
self.rnn = nn.RNN(
input_size=1,
hidden_size=32,
num_layers=2,
batch_first=True,
)
self.fc = nn.Linear(32, 1)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), 32)
out, _ = self.rnn(x, h0)
out = self.fc(out[:, -1, :])
return out