Save and load a model
A manufacturing company wants to classify its products from images to determine the appropriate shipping packaging. You have trained a highly accurate model in PyTorch, and now you want to save the model's trained weights for future use and share them with your team, making sure they can load the model seamlessly.
torch and torch.nn as nn have been imported. The pre-trained model object is available in your workspace as model, and its architecture as ManufacturingCNN.
This exercise is part of the course Deep Learning for Images with PyTorch.
Exercise instructions
- Save the pre-trained model's weights as ModelCNN.pth. Save its state dictionary, which holds the learned parameters; the architecture itself is not stored and comes from the class definition.
- Create a model instance called loaded_model from the class ManufacturingCNN().
- Load the weights from ModelCNN.pth into loaded_model by passing them to .load_state_dict().
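The three steps above can be sketched end to end as follows. The real ManufacturingCNN architecture is not shown in the exercise, so the class below is a hypothetical stand-in (a small CNN assuming 3x64x64 RGB input images and three packaging classes), and model here is untrained rather than the workspace's pre-trained model. Only the torch.save, state_dict, torch.load, and load_state_dict calls mirror the exercise.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for ManufacturingCNN: the real architecture is not
# given in the exercise. Assumes 3x64x64 RGB inputs and 3 packaging classes.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),  # keeps 64x64
            nn.ReLU(),
            nn.MaxPool2d(2),                             # 64x64 -> 32x32
        )
        self.classifier = nn.Linear(16 * 32 * 32, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, start_dim=1))

# Stand-in for the workspace's pre-trained model (untrained here).
model = ManufacturingCNN()

# Step 1: save only the learned parameters (the state_dict), not the object.
torch.save(model.state_dict(), "ModelCNN.pth")

# Step 2: create a fresh instance; its weights start randomly initialized.
loaded_model = ManufacturingCNN()

# Step 3: copy the saved parameters into the new instance.
loaded_model.load_state_dict(torch.load("ModelCNN.pth"))
print(loaded_model)

# Check that every saved tensor was restored exactly.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), loaded_model.state_dict().values())
)
print(same)  # True
```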
Have a go at this exercise by completing this sample code.
# Save the model
torch.____(model.____, ____)
# Create a new model
loaded_model = ____
# Load the saved model
loaded_model.____(torch.____('ModelCNN.pth'))
print(loaded_model)
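When ModelCNN.pth is shared with teammates, a few details of PyTorch's loading API tend to matter. The sketch below uses tiny hypothetical nn.Sequential models in place of ManufacturingCNN. It loads with map_location="cpu" so weights saved on a GPU can be opened on a CPU-only machine, passes weights_only=True (available in PyTorch 1.13 and later) to restrict unpickling to tensors and plain containers, and shows that load_state_dict() rejects a state dictionary whose tensor shapes do not match the receiving model.

```python
import torch
import torch.nn as nn

# Hypothetical tiny architecture, used only to illustrate loading behavior.
def make_model(hidden: int) -> nn.Module:
    return nn.Sequential(nn.Linear(8, hidden), nn.ReLU(), nn.Linear(hidden, 3))

trained = make_model(hidden=16)
torch.save(trained.state_dict(), "ModelCNN.pth")

# map_location="cpu" remaps any GPU tensors onto the CPU while loading;
# weights_only=True (PyTorch >= 1.13) only unpickles tensors and containers.
state = torch.load("ModelCNN.pth", map_location="cpu", weights_only=True)

# Loading into the same architecture succeeds with no missing/extra keys.
same_arch = make_model(hidden=16)
result = same_arch.load_state_dict(state)
print(result.missing_keys, result.unexpected_keys)  # [] []

# Switch layers such as dropout/batch norm to inference behavior.
same_arch.eval()

# Loading into a different architecture fails: the file stores parameters
# only, so the receiving code must define a matching model class.
other_arch = make_model(hidden=32)
try:
    other_arch.load_state_dict(state)
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True
print(mismatch_detected)  # True
```

Because the file holds parameters only, teammates need the same ManufacturingCNN class definition in their code before they can call load_state_dict().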