GRU réseau
Outre les LSTM, une autre variante populaire de réseau neuronal récurrent est l'unité récurrente gérée (Gated Recurrent Unit), ou GRU. Son attrait réside dans sa simplicité : les cellules GRU nécessitent moins de calculs que les cellules LSTM tout en offrant des performances souvent équivalentes.
Le code qui vous est fourni est la définition du modèle RNN que vous avez codé précédemment. Votre tâche consiste à l'adapter pour qu'il produise un réseau GRU à la place. torch
et torch.nn as nn
ont déjà été importés pour vous.
Cet exercice fait partie du cours
Apprentissage profond intermédiaire avec PyTorch
Instructions
- Mettez à jour la définition du modèle RNN afin d'obtenir un réseau GRU; affectez la couche GRU à
self.gru
.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
class Net(nn.Module):
def __init__(self):
super().__init__()
# Define RNN layer
self.rnn = nn.RNN(
input_size=1,
hidden_size=32,
num_layers=2,
batch_first=True,
)
self.fc = nn.Linear(32, 1)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), 32)
out, _ = self.rnn(x, h0)
out = self.fc(out[:, -1, :])
return out