Red LSTM
Como ya sabes, las celdas simples RNN no se utilizan mucho en la práctica. Una alternativa más utilizada y que garantiza un manejo mucho mejor de las secuencias largas son las celdas de memoria a corto plazo, o LSTM. En este ejercicio, vas a construir una red LSTM.
La diferencia de implementación más importante respecto a la red RNN que has construido anteriormente proviene del hecho de que las LSTM tienen dos estados ocultos en lugar de uno. Esto significa que tendrás que inicializar este estado oculto adicional y pasarlo a la celda LSTM.
torch
y torch.nn
ya se han importado para ti, ¡así que empieza a codificar!
Este ejercicio forma parte del curso
Aprendizaje profundo intermedio con PyTorch
Instrucciones del ejercicio
- En el método
.__init__()
, define una capa LSTM y asígnala aself.lstm
. - En el método
forward()
, inicializa el primer estado oculto de la memoria a largo plazoc0
con ceros. - En el método
forward()
, pasa las tres entradas a la capaLSTM
: las entradas del paso de tiempo actual y una tupla que contenga los dos estados ocultos.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
class Net(nn.Module):
def __init__(self, input_size):
super().__init__()
# Define lstm layer
____ = ____(
input_size=1,
hidden_size=32,
num_layers=2,
batch_first=True,
)
self.fc = nn.Linear(32, 1)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), 32)
# Initialize long-term memory
c0 = ____
# Pass all inputs to lstm layer
out, _ = ____
out = self.fc(out[:, -1, :])
return out