CommencerCommencer gratuitement

Problème de gradient qui s'annule

L’autre problème de gradient possible est lorsque les gradients s’annulent, ou tendent vers zéro. C’est un problème bien plus difficile à résoudre, car il est moins évident à détecter. Si la fonction de perte ne s’améliore pas à chaque étape, est-ce parce que les gradients sont tombés à zéro et n’ont donc pas mis à jour les poids ? Ou est-ce parce que le modèle n’arrive pas à apprendre ?

Ce problème survient plus souvent dans les modèles RNN lorsque la mémoire longue est nécessaire (phrases très longues).

Dans cet exercice, vous allez observer ce phénomène sur les données IMDB, avec des phrases plus longues sélectionnées. Les données sont chargées dans les variables X et y, ainsi que les classes Sequential, SimpleRNN, Dense et matplotlib.pyplot sous le nom plt. Le modèle a été pré-entraîné pendant 100 époques ; ses poids et son historique sont enregistrés dans le fichier model_weights.h5 et la variable history.

Cet exercice fait partie du cours

Réseaux de neurones récurrents (RNN) pour la modélisation du langage avec Keras

Afficher le cours

Instructions

  • Ajoutez une couche SimpleRNN au modèle.
  • Chargez les poids pré-entraînés dans le modèle avec la méthode .load_weights().
  • Ajoutez au graphique la précision des données d’entraînement disponible dans l’attribut 'acc'.
  • Affichez le graphique avec la méthode .show().

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

# Create the model
model = Sequential()
model.add(____(units=600, input_shape=(None, 1)))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])

# Load pre-trained weights
model.____('model_weights.h5')

# Plot the accuracy x epoch graph
plt.plot(history.history[____])
plt.plot(history.history['val_acc'])
plt.legend(['train', 'val'], loc='upper left')
plt.____()
Modifier et exécuter le code