Rete GRU
Accanto alle LSTM, un'altra variante molto diffusa delle reti neurali ricorrenti è la Gated Recurrent Unit, o GRU. Il suo punto di forza è la semplicità: le celle GRU richiedono meno calcolo rispetto alle celle LSTM, pur offrendo spesso prestazioni comparabili.
Il codice fornito è la definizione del modello RNN che hai scritto in precedenza. Il tuo compito è adattarlo in modo che produca invece una rete GRU. torch e torch.nn as nn sono già stati importati per te.
Questo esercizio fa parte del corso
Deep Learning intermedio con PyTorch
Istruzioni dell'esercizio
- Aggiorna la definizione del modello RNN per ottenere una rete GRU; assegna il livello GRU a
self.gru.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
class Net(nn.Module):
def __init__(self):
super().__init__()
# Define RNN layer
self.rnn = nn.RNN(
input_size=1,
hidden_size=32,
num_layers=2,
batch_first=True,
)
self.fc = nn.Linear(32, 1)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), 32)
out, _ = self.rnn(x, h0)
out = self.fc(out[:, -1, :])
return out