IniziaInizia gratis

Rete GRU

Accanto alle LSTM, un'altra variante molto diffusa delle reti neurali ricorrenti è la Gated Recurrent Unit, o GRU. Il suo punto di forza è la semplicità: le celle GRU richiedono meno calcolo rispetto alle celle LSTM, pur offrendo spesso prestazioni comparabili.

Il codice fornito è la definizione del modello RNN che hai scritto in precedenza. Il tuo compito è adattarlo in modo che produca invece una rete GRU. torch e torch.nn as nn sono già stati importati per te.

Questo esercizio fa parte del corso

Deep Learning intermedio con PyTorch

Visualizza il corso

Istruzioni dell'esercizio

  • Aggiorna la definizione del modello RNN per ottenere una rete GRU; assegna il livello GRU a self.gru.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Define RNN layer
        self.rnn = nn.RNN(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        h0 = torch.zeros(2, x.size(0), 32)
        out, _ = self.rnn(x, h0)  
        out = self.fc(out[:, -1, :])
        return out
Modifica ed esegui il codice