CommencerCommencer gratuitement

Réseau GRU

À côté des LSTM, une autre variante populaire des réseaux de neurones récurrents est le Gated Recurrent Unit, ou GRU. Son intérêt réside dans sa simplicité : les cellules GRU nécessitent moins de calcul que les cellules LSTM tout en offrant souvent des performances comparables.

Le code fourni correspond à la définition du modèle RNN que vous avez écrit précédemment. Votre tâche est de l’adapter pour obtenir un réseau GRU à la place. torch et torch.nn as nn ont déjà été importés pour vous.

Cet exercice fait partie du cours

Deep learning intermédiaire avec PyTorch

Afficher le cours

Instructions

  • Mettez à jour la définition du modèle RNN afin d’obtenir un réseau GRU ; affectez la couche GRU à self.gru.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Define RNN layer
        self.rnn = nn.RNN(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        h0 = torch.zeros(2, x.size(0), 32)
        out, _ = self.rnn(x, h0)  
        out = self.fc(out[:, -1, :])
        return out
Modifier et exécuter le code