Aan de slagGa gratis aan de slag

Vanishing gradient-probleem

Het andere mogelijke gradientprobleem is wanneer de gradients verdwijnen, of naar nul gaan. Dit is een veel lastiger probleem om op te lossen, omdat het moeilijker te detecteren is. Als de verliesfunctie niet bij elke stap verbetert, komt dat dan doordat de gradients naar nul gingen en de gewichten dus niet zijn bijgewerkt? Of komt het omdat het model niet kan leren?

Dit probleem komt vaker voor in RNN-modellen wanneer lange geheugenafhankelijkheid nodig is (bij lange zinnen).

In deze oefening ga je het probleem bekijken op de IMDB-gegevens, waarbij langere zinnen zijn geselecteerd. De data is geladen in de variabelen X en y, en de klassen Sequential, SimpleRNN, Dense en matplotlib.pyplot als plt zijn beschikbaar. Het model is vooraf getraind met 100 epochs; de gewichten en de geschiedenis zijn opgeslagen in het bestand model_weights.h5 en de variabele history.

Deze oefening maakt deel uit van de cursus

Recurrent Neural Networks (RNN's) voor taalmodellen met Keras

Cursus bekijken

Oefeninstructies

  • Voeg een SimpleRNN-laag toe aan het model.
  • Laad de vooraf getrainde gewichten in het model met de methode .load_weights().
  • Voeg de nauwkeurigheid van de trainingsdata, beschikbaar onder het attribuut 'acc', toe aan de grafiek.
  • Toon de grafiek met de methode .show().

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

# Create the model
model = Sequential()
model.add(____(units=600, input_shape=(None, 1)))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])

# Load pre-trained weights
model.____('model_weights.h5')

# Plot the accuracy x epoch graph
plt.plot(history.history[____])
plt.plot(history.history['val_acc'])
plt.legend(['train', 'val'], loc='upper left')
plt.____()
Code bewerken en uitvoeren