Save and load a model
A manufacturing company wants to classify its products based on images and determine the appropriate shipping packaging. Having trained a highly accurate model in PyTorch, you now plan to save the model and its trained weights for future use and to share them with your team, making sure they can load the model seamlessly.
torch and torch.nn as nn have been imported. The pre-trained model object is available in your workspace as model, and its architecture as ManufacturingCNN.
This exercise is part of the course Deep Learning for Images with PyTorch.
Exercise instructions
- Save the pre-trained model as ModelCNN.pth, remembering to save the weights, not only the architecture.
- Create a model instance called loaded_model from the class ManufacturingCNN().
- Load the ModelCNN.pth weights into loaded_model by passing the weights to .load_state_dict().
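To make "the weights, not only the architecture" concrete, the short sketch below prints what a model's state_dict() holds: an ordered mapping from parameter names to tensors. TinyCNN is a hypothetical stand-in invented for illustration, since the exercise does not show the layers of ManufacturingCNN.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in model; the real ManufacturingCNN layers are not shown in the exercise.
class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = x.mean(dim=(2, 3))  # global average pool to shape (batch, 4)
        return self.fc(x)


tiny = TinyCNN()

# state_dict() maps each parameter (and buffer) name to its tensor.
state = tiny.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

The mapping contains only tensors keyed by name, which is why the class definition is still needed when loading the weights later.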
Hands-on interactive exercise
Complete this sample code to finish the exercise.
# Save the model
torch.____(model.____, ____)
# Create a new model
loaded_model = ____
# Load the saved model
loaded_model.____(torch.____('ModelCNN.pth'))
print(loaded_model)
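
One possible way to fill in the blanks, following the instructions above, is sketched here as a self-contained script. The ManufacturingCNN class below, including its layers and the num_classes=4 default, is a hypothetical stand-in, because the exercise supplies the real class and the pre-trained model in the workspace without showing them. The file is written to a temporary directory (an assumption made to keep the sketch free of side effects); in the exercise itself the path is simply 'ModelCNN.pth'.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-in: the exercise provides ManufacturingCNN in the workspace
# but does not show its layers, so this architecture is invented for illustration.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes=4):  # num_classes=4 is an assumed value
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x).flatten(1)
        return self.classifier(x)


model = ManufacturingCNN()  # stands in for the pre-trained model object

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "ModelCNN.pth")

    # Save the model: only the weights (state_dict), not the pickled object
    torch.save(model.state_dict(), path)

    # Create a new model with the same architecture
    loaded_model = ManufacturingCNN()

    # Load the saved weights into the new instance
    loaded_model.load_state_dict(torch.load(path))

# Switch to evaluation mode before using the loaded model for inference
loaded_model.eval()

# Check that every saved tensor matches the corresponding loaded tensor
weights_match = all(
    torch.equal(tensor, loaded_model.state_dict()[name])
    for name, tensor in model.state_dict().items()
)
print(loaded_model)
print(weights_match)
```

Saving the state_dict rather than the whole model object keeps the file independent of how the class was pickled, at the cost of requiring teammates to have the ManufacturingCNN definition available when they load it.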