
Save and load a model

A manufacturing company wants to classify its products based on images to determine the appropriate shipping packaging. Having trained a highly accurate model in PyTorch, you now plan to save the model's pre-trained weights for future use and share them with your team, making sure they can load the model seamlessly.

The modules torch and torch.nn (as nn) have already been imported. The pre-trained model object is available in your workspace as model, and its architecture as the class ManufacturingCNN.
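The split between the model object and its architecture matters because a model's state_dict holds only its learned tensors, keyed by parameter name, and not the class definition. The sketch below prints those keys and shapes. The exercise does not show ManufacturingCNN, so the class here is a hypothetical stand-in whose layer names and sizes are assumptions for illustration only.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for ManufacturingCNN; the real architecture is not given in the exercise.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)  # has no learnable parameters
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)


model = ManufacturingCNN()

# state_dict() maps parameter names to tensors: the weights, not the code of the class.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
```

Because the class itself is not stored, whoever loads the weights later still needs the ManufacturingCNN definition to rebuild the model.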

This exercise is part of the course

Deep Learning for Images with PyTorch


Exercise instructions

  • Save the pre-trained model's weights to ModelCNN.pth by saving its state_dict; the architecture itself is not stored and is recreated from the ManufacturingCNN class.
  • Create a model instance called loaded_model from the class ManufacturingCNN().
  • Load ModelCNN.pth weights to loaded_model by passing the weights to .load_state_dict().
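The three steps above amount to a save/rebuild/load round trip. Here is a minimal sketch of it, assuming a hypothetical stand-in for ManufacturingCNN (the real class is not shown) and writing ModelCNN.pth to a temporary directory rather than the working directory so no file is left behind. The final check confirms that the freshly constructed model, which starts with different random weights, ends up with exactly the saved tensors.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-in for ManufacturingCNN; layer sizes are assumptions.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.fc(self.pool(x).flatten(1))


model = ManufacturingCNN()  # stands in for the trained model

# Step 1: save only the learned parameters (the state_dict).
path = os.path.join(tempfile.mkdtemp(), "ModelCNN.pth")
torch.save(model.state_dict(), path)

# Step 2: rebuild the architecture from the class (random initial weights).
loaded_model = ManufacturingCNN()

# Step 3: copy the saved weights into the new instance.
loaded_model.load_state_dict(torch.load(path))

# Every tensor should now match the original model's.
loaded_state = loaded_model.state_dict()
same = all(torch.equal(t, loaded_state[k]) for k, t in model.state_dict().items())
print(same)  # True
```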

Hands-on interactive exercise

Try this exercise by completing the sample code below.

# Save the model
torch.____(model.____, ____)

# Create a new model
loaded_model = ____

# Load the saved model
loaded_model.____(torch.____('ModelCNN.pth'))
print(loaded_model)
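Note that print(loaded_model) shows the module structure, not the weights. When a teammate loads the file on another machine, a few extra options to torch.load and load_state_dict can help. The sketch below uses map_location="cpu" so weights saved on a GPU also load on a CPU-only machine, uses weights_only=True to restrict unpickling to tensors and plain data (available in recent PyTorch releases), inspects the keys reported by load_state_dict, and switches to eval() before inference. The class is again a hypothetical stand-in, and an in-memory buffer replaces ModelCNN.pth so the sketch writes no files.

```python
import io

import torch
import torch.nn as nn


# Hypothetical stand-in for ManufacturingCNN; layer sizes are assumptions.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.fc(self.pool(x).flatten(1))


# Serialize to an in-memory buffer instead of ModelCNN.pth.
buffer = io.BytesIO()
torch.save(ManufacturingCNN().state_dict(), buffer)
buffer.seek(0)

# Map all tensors to the CPU and only allow tensor/plain-data unpickling.
state_dict = torch.load(buffer, map_location="cpu", weights_only=True)

loaded_model = ManufacturingCNN()
# strict=True (the default) raises on mismatched keys; the result lists any that differ.
result = loaded_model.load_state_dict(state_dict)
print(result.missing_keys, result.unexpected_keys)  # [] [] on a clean load

# Put layers such as dropout and batch norm into inference mode before predicting.
loaded_model.eval()
with torch.no_grad():
    logits = loaded_model(torch.randn(1, 3, 32, 32))
print(logits.shape)  # torch.Size([1, 3])
```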