Ajouter des couches à un réseau
Vous avez vu comment expérimenter avec des réseaux plus larges. Dans cet exercice, vous allez essayer un réseau plus profond (plus de couches cachées).
Encore une fois, vous disposez d’un modèle de référence appelé model_1 comme point de départ. Il comporte 1 couche cachée avec 10 unités. Vous pouvez voir un récapitulatif de la structure de ce modèle. Vous allez créer un réseau similaire avec 3 couches cachées (en conservant 10 unités dans chaque couche).
L’ajustement des deux modèles prendra de nouveau un moment, vous devrez donc patienter quelques secondes pour voir les résultats après avoir exécuté votre code.
Cet exercice fait partie du cours
Introduction au Deep Learning en Python
Instructions
- Indiquez un modèle appelé
model_2semblable àmodel_1, mais avec 3 couches cachées de 10 unités au lieu d’une seule couche cachée.- Utilisez
input_shapepour spécifier la forme d’entrée dans la première couche cachée. - Utilisez l’activation
'relu'pour les 3 couches cachées et'softmax'pour la couche de sortie, qui doit avoir 2 unités.
- Utilisez
- Compilez
model_2comme pour les modèles précédents : utilisez'adam'commeoptimizer,'categorical_crossentropy'pour la perte, etmetrics=['accuracy']. - Cliquez sur "Soumettre la réponse" pour ajuster les deux modèles et visualiser lequel donne les meilleurs résultats !
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# The input shape to use in the first hidden layer
input_shape = (n_cols,)
# Create the new model: model_2
model_2 = ____
# Add the first, second, and third hidden layers
____
____
____
# Add the output layer
____
# Compile model_2
____
# Fit model 1
model_1_training = model_1.fit(predictors, target, epochs=15, validation_split=0.4, verbose=False)
# Fit model 2
model_2_training = model_2.fit(predictors, target, epochs=15, validation_split=0.4, verbose=False)
# Create the plot
plt.plot(model_1_training.history['val_loss'], 'r', model_2_training.history['val_loss'], 'b')
plt.xlabel('Epochs')
plt.ylabel('Validation score')
plt.show()