Enregistrer et charger un modèle
Une entreprise manufacturière souhaite classer ses projets en fonction d'images et déterminer l'emballage d'expédition approprié. Après avoir formé un modèle très précis dans PyTorch, vous envisagez maintenant de sauvegarder le modèle et ses poids pré-entraînés pour une utilisation future et de le partager avec votre équipe, en vous assurant qu'elle pourra le charger sans difficulté.
torch
torch.nn
et nn
ont été importés. L'objet modèle pré-entraîné est disponible dans votre espace de travail sous le nom model
, et son architecture sous le nom ManufacturingCNN
.
Cet exercice fait partie du cours
Deep learning pour les images avec PyTorch
Instructions
- Enregistrez le modèle pré-entraîné sous le nom d'
ModelCNN.pth
. Veuillez vous assurer d'enregistrer les poids, et pas seulement l'architecture. - Créez une instance de modèle appelée «
loaded_model
» à partir de la classe «ManufacturingCNN()
». - Veuillez charger les poids d'
ModelCNN.pth
surloaded_model
en les transmettant à.load_state_dict()
.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Save the model
torch.____(model.____, ____)
# Create a new model
loaded_model = ____
# Load the saved model
loaded_model.____(torch.____('ModelCNN.pth'))
print(loaded_model)