Save and load a model
A manufacturing company wants to classify its products based on images and determine the appropriate shipping packaging. Having trained a highly accurate model in PyTorch, you now plan to save the model and its trained weights for future use and to share it with your team so that they can load it seamlessly.
torch and torch.nn as nn have been imported. The pre-trained model object is available in your workspace as model, and its architecture as ManufacturingCNN.
This exercise is part of the course
Deep Learning for Images with PyTorch
Exercise instructions
- Save the pre-trained model as ModelCNN.pth, remembering to save the weights, not only the architecture.
- Create a model instance called loaded_model from the class ManufacturingCNN().
- Load the ModelCNN.pth weights into loaded_model by passing the weights to .load_state_dict().
Hands-on interactive exercise
Try to solve this exercise by completing the sample code.
# Save the model
torch.____(model.____, ____)
# Create a new model
loaded_model = ____
# Load the saved model
loaded_model.____(torch.____('ModelCNN.pth'))
print(loaded_model)
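
For reference, the general save-and-reload pattern described in the instructions can be sketched on a self-contained example. The definition of ManufacturingCNN is not shown here, so the sketch uses a hypothetical stand-in class, TinyCNN, and a temporary file path; both are assumptions made for illustration and are not the course's architecture or file name.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical stand-in for ManufacturingCNN (not the course's architecture)."""

    def __init__(self, num_classes=4):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


model = TinyCNN()

# Hypothetical location for the weights file (a temp dir keeps the sketch tidy)
path = os.path.join(tempfile.mkdtemp(), "TinyCNN.pth")

# Save only the learned parameters (the state dict), not the pickled model object
torch.save(model.state_dict(), path)

# Recreate the architecture first, then copy the saved weights into it
loaded_model = TinyCNN()
loaded_model.load_state_dict(torch.load(path))
loaded_model.eval()  # switch to inference mode before making predictions

# Check that every parameter tensor survived the round trip unchanged
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), loaded_model.state_dict().values())
)
print(weights_match)
```

Saving the state dict rather than the whole model object means teammates need the class definition to rebuild the model, but the saved file is not tied to a particular pickled class path on disk.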