Save and load a model
A manufacturing company wants to classify its products based on images and determine the appropriate shipping packaging. Having trained a highly accurate model in PyTorch, you now plan to save the model's pre-trained weights for future use and share them with your team, making sure they can load the model seamlessly.
torch and torch.nn as nn have been imported. The pre-trained model object is available in your workspace as model, and its architecture as ManufacturingCNN.
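To see what saving "the weights" means in practice, the sketch below builds a hypothetical stand-in for the workspace objects. The layers of this ManufacturingCNN, `num_classes=3`, and the 64x64 RGB input size are all assumptions, since the exercise does not show the real architecture. Printing `model.state_dict()` lists every learnable tensor by name, and this dictionary is what gets written to disk when you save a model's weights.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in: the exercise does not show ManufacturingCNN's real layers.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes=3):  # num_classes is an assumption
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),  # 3-channel (RGB) input
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # halves height and width
            nn.Flatten(),
        )
        # Assumes 64x64 inputs: 16 channels * 32 * 32 values after pooling
        self.classifier = nn.Linear(16 * 32 * 32, num_classes)

    def forward(self, x):
        return self.classifier(self.feature_extractor(x))

model = ManufacturingCNN()  # stands in for the pre-trained workspace model

# The state dict maps each parameter name to its tensor
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
```

For this stand-in, the loop prints four entries: `feature_extractor.0.weight (16, 3, 3, 3)`, `feature_extractor.0.bias (16,)`, `classifier.weight (3, 16384)` and `classifier.bias (3,)`. The ReLU, MaxPool2d and Flatten layers have no learnable parameters, so they do not appear.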
This exercise is part of the course Deep Learning for Images with PyTorch.
Exercise instructions
- Save the pre-trained model's weights (its state dict) as ModelCNN.pth.
- Create a model instance called loaded_model from the class ManufacturingCNN().
- Load the ModelCNN.pth weights into loaded_model by passing them to .load_state_dict().
Have a go at this exercise by completing this sample code.
# Save the model's weights (its state dict) to ModelCNN.pth
torch.____(model.____, ____)
# Create a new instance of the model architecture
loaded_model = ____
# Load the saved weights into the new instance
loaded_model.____(torch.____('ModelCNN.pth'))
print(loaded_model)
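One way to complete the sample code is sketched below as a self-contained script. Because the real ManufacturingCNN is not shown, the class here is a hypothetical stand-in (its layers, `num_classes=3` and the 64x64 input size are assumptions), and `model` holds random rather than pre-trained weights. After the three steps, the sketch checks that the reloaded model reproduces the original's outputs and that the weights cannot be loaded into a differently shaped architecture.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for ManufacturingCNN; the exercise does not show its layers.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes=3):  # num_classes is an assumption
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(16 * 32 * 32, num_classes)  # assumes 64x64 inputs

    def forward(self, x):
        return self.classifier(self.feature_extractor(x))

model = ManufacturingCNN()  # stands in for the pre-trained workspace model

# Save only the learned parameters (the state dict), not the whole model object
torch.save(model.state_dict(), "ModelCNN.pth")

# Create a fresh instance of the same architecture
loaded_model = ManufacturingCNN()

# Load the saved weights into the new instance; map_location="cpu" lets
# weights saved on a GPU be loaded on a CPU-only machine
loaded_model.load_state_dict(torch.load("ModelCNN.pth", map_location="cpu"))
print(loaded_model)

# Sanity check: in eval mode both models should produce the same outputs
model.eval()
loaded_model.eval()
images = torch.randn(2, 3, 64, 64)  # hypothetical batch of two 64x64 RGB images
with torch.no_grad():
    same_outputs = torch.allclose(model(images), loaded_model(images))
print(same_outputs)  # True

# The weights only fit a matching architecture: a classifier with a different
# number of outputs has differently shaped tensors, so loading raises an error
try:
    ManufacturingCNN(num_classes=5).load_state_dict(
        torch.load("ModelCNN.pth", map_location="cpu")
    )
    mismatch_error = False
except RuntimeError:
    mismatch_error = True
print(mismatch_error)  # True
```

Because the file holds only tensors keyed by parameter name, anyone on the team who loads ModelCNN.pth also needs the ManufacturingCNN class definition to rebuild the architecture before calling `.load_state_dict()`.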