CommencerCommencer gratuitement

Réseau LSTM

Comme vous le savez déjà, les cellules RNN simples sont rarement utilisées en pratique. Une alternative bien plus répandue, qui gère bien mieux les longues séquences, est la cellule Long Short-Term Memory, ou LSTM. Dans cet exercice, vous allez construire vous-même un réseau LSTM !

La différence d’implémentation la plus importante par rapport au réseau RNN que vous avez construit précédemment vient du fait que les LSTM possèdent deux états cachés au lieu d’un. Vous devrez donc initialiser cet état caché supplémentaire et le transmettre à la cellule LSTM.

torch et torch.nn ont déjà été importés pour vous, alors lancez-vous !

Cet exercice fait partie du cours

Deep learning intermédiaire avec PyTorch

Afficher le cours

Instructions

  • Dans la méthode .__init__(), définissez une couche LSTM et assignez-la à self.lstm.
  • Dans la méthode forward(), initialisez le premier état caché de la mémoire à long terme c0 avec des zéros.
  • Dans la méthode forward(), passez les trois entrées à la couche LSTM : les entrées au pas de temps courant, et un tuple contenant les deux états cachés.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Define lstm layer
        ____ = ____(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        h0 = torch.zeros(2, x.size(0), 32)
        # Initialize long-term memory
        c0 = ____
        # Pass all inputs to lstm layer
        out, _ = ____
        out = self.fc(out[:, -1, :])
        return out
Modifier et exécuter le code