Problème de gradient décroissant
L'autre problème possible lié aux gradients est leur disparition ou leur passage à zéro. Il s'agit d'un problème beaucoup plus complexe à résoudre, car il est plus difficile à détecter. Si la fonction de perte ne s'améliore pas à chaque étape, est-ce parce que les gradients sont passés à zéro et n'ont donc pas mis à jour les poids ? Ou est-ce parce que le modèle n'est pas capable d'apprendre ?
Ce problème survient plus fréquemment dans les modèles RNN lorsque la mémoire est importante (phrases longues).
Dans cet exercice, vous observerez le problème sur les données IMDB, avec des phrases plus longues sélectionnées. Les données sont chargées dans les variables X
et y
, ainsi que dans les classes Sequential
, SimpleRNN
, Dense
et matplotlib.pyplot
sous le nom plt
. Le modèle a été pré-entraîné avec 100 époques, ses poids et son historique sont stockés dans le fichier model_weights.h5
et la variable history
.
Cet exercice fait partie du cours
Réseaux neuronaux récurrents (RNN) pour la modélisation du langage avec Keras
Instructions
- Ajoutez une couche d'
SimpleRNN
s au modèle. - Chargez les poids pré-entraînés sur le modèle à l'aide de la méthode «
.load_weights()
». - Ajoutez au graphique la précision des données d'apprentissage disponibles sur l'
'acc'
d'attribut. - Affichez le graphique à l'aide de la méthode
.show()
.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Create the model
model = Sequential()
model.add(____(units=600, input_shape=(None, 1)))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
# Load pre-trained weights
model.____('model_weights.h5')
# Plot the accuracy x epoch graph
plt.plot(history.history[____])
plt.plot(history.history['val_acc'])
plt.legend(['train', 'val'], loc='upper left')
plt.____()