Réseau LSTM
Comme vous le savez déjà, les cellules RNN simples sont rarement utilisées en pratique. Une alternative bien plus répandue, qui gère bien mieux les longues séquences, est la cellule Long Short-Term Memory, ou LSTM. Dans cet exercice, vous allez construire vous-même un réseau LSTM !
La différence d’implémentation la plus importante par rapport au réseau RNN que vous avez construit précédemment vient du fait que les LSTM possèdent deux états cachés au lieu d’un. Vous devrez donc initialiser cet état caché supplémentaire et le transmettre à la cellule LSTM.
torch et torch.nn ont déjà été importés pour vous, alors lancez-vous !
Cet exercice fait partie du cours
Deep learning intermédiaire avec PyTorch
Instructions
- Dans la méthode
.__init__(), définissez une couche LSTM et assignez-la àself.lstm. - Dans la méthode
forward(), initialisez le premier état caché de la mémoire à long termec0avec des zéros. - Dans la méthode
forward(), passez les trois entrées à la coucheLSTM: les entrées au pas de temps courant, et un tuple contenant les deux états cachés.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
class Net(nn.Module):
def __init__(self, input_size):
super().__init__()
# Define lstm layer
____ = ____(
input_size=1,
hidden_size=32,
num_layers=2,
batch_first=True,
)
self.fc = nn.Linear(32, 1)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), 32)
# Initialize long-term memory
c0 = ____
# Pass all inputs to lstm layer
out, _ = ____
out = self.fc(out[:, -1, :])
return out