CommencerCommencer gratuitement

Problème de gradient décroissant

L'autre problème possible lié aux gradients est leur disparition ou leur passage à zéro. Il s'agit d'un problème beaucoup plus complexe à résoudre, car il est plus difficile à détecter. Si la fonction de perte ne s'améliore pas à chaque étape, est-ce parce que les gradients sont passés à zéro et n'ont donc pas mis à jour les poids ? Ou est-ce parce que le modèle n'est pas capable d'apprendre ?

Ce problème survient plus fréquemment dans les modèles RNN lorsque la mémoire est importante (phrases longues).

Dans cet exercice, vous observerez le problème sur les données IMDB, avec des phrases plus longues sélectionnées. Les données sont chargées dans les variables X et y, ainsi que dans les classes Sequential, SimpleRNN, Dense et matplotlib.pyplot sous le nom plt. Le modèle a été pré-entraîné avec 100 époques, ses poids et son historique sont stockés dans le fichier model_weights.h5 et la variable history.

Cet exercice fait partie du cours

Réseaux neuronaux récurrents (RNN) pour la modélisation du langage avec Keras

Afficher le cours

Instructions

  • Ajoutez une couche d'SimpleRNN s au modèle.
  • Chargez les poids pré-entraînés sur le modèle à l'aide de la méthode « .load_weights() ».
  • Ajoutez au graphique la précision des données d'apprentissage disponibles sur l''acc' d'attribut.
  • Affichez le graphique à l'aide de la méthode .show().

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

# Create the model
model = Sequential()
model.add(____(units=600, input_shape=(None, 1)))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])

# Load pre-trained weights
model.____('model_weights.h5')

# Plot the accuracy x epoch graph
plt.plot(history.history[____])
plt.plot(history.history['val_acc'])
plt.legend(['train', 'val'], loc='upper left')
plt.____()
Modifier et exécuter le code