CommencerCommencer gratuitement

Effectuer plusieurs mises à jour des pondérations

Vous allez maintenant effectuer plusieurs mises à jour afin d'améliorer considérablement les poids de votre modèle et observer comment les prédictions s'améliorent à chaque mise à jour.

get_slope() Afin de maintenir la propreté de votre code, une fonction préchargée, « preloadunloadscripts », est disponible. Elle prend comme arguments les fichiers « input_data », « target » et « weights ». Il existe également une fonction d'get_mse() qui prend les mêmes arguments. Les fichiers input_data, target et weights ont été préchargés.

Ce réseau ne comporte aucune couche cachée et passe directement de l'entrée (avec 3 nœuds) à un nœud de sortie. Veuillez noter que « weights » est un tableau unique.

Nous avons également préchargé l'matplotlib.pyplot, et l'historique des erreurs sera représenté graphiquement une fois que vous aurez effectué vos étapes de descente de gradient.

Cet exercice fait partie du cours

Introduction au Deep Learning avec Python

Afficher le cours

Instructions

  • Utilisation d'une boucle « for » pour mettre à jour les poids de manière itérative :

    • Calculez la pente à l'aide de la fonction « get_slope() ».

    • Mettez à jour les poids en utilisant un taux d'apprentissage de 0.01.

    • Calculez l'erreur quadratique moyenne (mse) avec les poids mis à jour à l'aide de la fonction get_mse().

    • Veuillez ajouter mse à mse_hist.

  • Veuillez cliquer sur « Soumettre la réponse » pour visualiser l'mse_hist. Quelle tendance remarquez-vous ?

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

n_updates = 20
mse_hist = []

# Iterate over the number of updates
for i in range(n_updates):
    # Calculate the slope: slope
    slope = ____(____, ____, ____)
    
    # Update the weights: weights
    weights = ____ - ____ * ____
    
    # Calculate mse with new weights: mse
    mse = ____(____, ____, ____)
    
    # Append the mse to mse_hist
    ____

# Plot the mse history
plt.plot(mse_hist)
plt.xlabel('Iterations')
plt.ylabel('Mean Squared Error')
plt.show()
Modifier et exécuter le code