CommencerCommencer gratuitement

GRU réseau

Outre les LSTM, une autre variante populaire de réseau neuronal récurrent est l'unité récurrente gérée (Gated Recurrent Unit), ou GRU. Son attrait réside dans sa simplicité : les cellules GRU nécessitent moins de calculs que les cellules LSTM tout en offrant des performances souvent équivalentes.

Le code qui vous est fourni est la définition du modèle RNN que vous avez codé précédemment. Votre tâche consiste à l'adapter pour qu'il produise un réseau GRU à la place. torch et torch.nn as nn ont déjà été importés pour vous.

Cet exercice fait partie du cours

Apprentissage profond intermédiaire avec PyTorch

Afficher le cours

Instructions

  • Mettez à jour la définition du modèle RNN afin d'obtenir un réseau GRU; affectez la couche GRU à self.gru.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Define RNN layer
        self.rnn = nn.RNN(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        h0 = torch.zeros(2, x.size(0), 32)
        out, _ = self.rnn(x, h0)  
        out = self.fc(out[:, -1, :])
        return out
Modifier et exécuter le code