Exercise

Save and load a model

A manufacturing company wants to classify its products from images in order to determine the appropriate shipping packaging. Having trained a highly accurate model in PyTorch, you now plan to save the model and its trained weights for future use and to share them with your team, so that they can load the model seamlessly.

torch and torch.nn as nn have been imported. The trained model object is available in your workspace as model, and its architecture is defined by the class ManufacturingCNN.

Instructions

  • Save the pre-trained model as ModelCNN.pth, remembering to save the weights (the model's state_dict), not only the architecture.
  • Create an instance of the ManufacturingCNN class called loaded_model.
  • Load the weights from ModelCNN.pth into loaded_model by passing them to .load_state_dict().
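
The three steps above can be sketched as follows. This is a minimal illustration, not the exercise's reference solution: the real ManufacturingCNN architecture is not shown here, so the class below is a small hypothetical stand-in, and the freshly constructed model takes the place of the trained model that would already be in the workspace.

```python
import torch
import torch.nn as nn


class ManufacturingCNN(nn.Module):
    """Hypothetical stand-in; the exercise's actual architecture is not shown."""

    def __init__(self, num_classes=4):  # num_classes is an assumed value
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


# Stands in for the trained model object provided in the workspace.
model = ManufacturingCNN()

# Step 1: save the learned parameters (the state_dict) to disk.
torch.save(model.state_dict(), "ModelCNN.pth")

# Step 2: create a fresh instance of the same architecture.
loaded_model = ManufacturingCNN()

# Step 3: read the saved weights and copy them into the new instance.
loaded_model.load_state_dict(torch.load("ModelCNN.pth"))
```

Because the file holds only the state_dict, whoever loads it needs access to the ManufacturingCNN class definition. If the loaded model is used for inference, calling loaded_model.eval() first is the usual practice so that layers such as dropout and batch normalization behave deterministically.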