
Save and load a model

A manufacturing company wants to classify its products from images and determine the appropriate shipping packaging. You have trained a highly accurate model in PyTorch. Now you want to save its pre-trained weights for future use and share them with your team, so that they can load the model seamlessly.

torch and torch.nn (as nn) have already been imported. The pre-trained model object is available in your workspace as model, and its architecture is available as the class ManufacturingCNN.

This exercise is part of the course

Deep Learning for Images with PyTorch


Exercise instructions

  • Save the pre-trained model's weights to ModelCNN.pth by saving its state dictionary (model.state_dict()), which holds the learned parameters rather than the architecture.
  • Create a model instance called loaded_model from the class ManufacturingCNN.
  • Load the weights from ModelCNN.pth into loaded_model by passing them to .load_state_dict().
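The three steps above can be sketched end to end as follows. The real ManufacturingCNN definition is not shown in the exercise, so this sketch uses a small hypothetical stand-in architecture. Its layer sizes, the number of classes, and the temporary file location are all assumptions. The saved file holds only the state dict, so you must instantiate the class before you can load any weights into it.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-in for ManufacturingCNN; the real architecture is not shown.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Linear(8 * 16 * 16, num_classes)  # assumes 32x32 inputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


model = ManufacturingCNN()

# Step 1: save only the learned parameters (the state dict), not the class itself.
path = os.path.join(tempfile.mkdtemp(), "ModelCNN.pth")  # temporary location for this sketch
torch.save(model.state_dict(), path)

# Step 2: create a fresh, randomly initialized instance of the same architecture.
loaded_model = ManufacturingCNN()

# Step 3: copy the saved weights into the new instance.
result = loaded_model.load_state_dict(torch.load(path))

# Check that every parameter tensor now matches the original model.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), loaded_model.state_dict().values())
)
print(weights_match)
```

Because only the state dict is saved, anyone loading ModelCNN.pth needs access to the same class definition. If the architecture changes, load_state_dict reports missing or unexpected keys, or raises a size-mismatch error.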

Interactive hands-on exercise

Try to solve this exercise by completing the sample code.

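import torch

# model and ManufacturingCNN are provided by the exercise workspace

# Save the model's weights (its state dict) to ModelCNN.pth
torch.save(model.state_dict(), 'ModelCNN.pth')

# Create a new instance of the same architecture
loaded_model = ManufacturingCNN()

# Load the saved weights into the new instance
loaded_model.load_state_dict(torch.load('ModelCNN.pth'))
print(loaded_model)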
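Before you use the reloaded model for predictions, you usually switch it to evaluation mode with .eval(), which disables layers such as dropout. If the file may have been saved on a GPU but is loaded on a CPU-only machine, you can also pass map_location="cpu" to torch.load. The sketch below checks that the reloaded model reproduces the original model's outputs. It relies on several assumptions: a hypothetical stand-in architecture, an in-memory buffer in place of ModelCNN.pth, and a dummy batch of 32x32 RGB images.

```python
import io

import torch
import torch.nn as nn


# Hypothetical stand-in architecture; the real ManufacturingCNN is not shown.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(4 * 32 * 32, num_classes)  # assumes 32x32 inputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.conv(x))
        return self.fc(self.dropout(torch.flatten(x, 1)))


torch.manual_seed(0)
model = ManufacturingCNN()

# Save the state dict to an in-memory buffer instead of a file (sketch only).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# map_location="cpu" remaps any GPU-saved tensors onto the CPU while loading.
loaded_model = ManufacturingCNN()
loaded_model.load_state_dict(torch.load(buffer, map_location="cpu"))

# eval() disables dropout so that both models compute deterministic outputs.
model.eval()
loaded_model.eval()

x = torch.randn(2, 3, 32, 32)  # dummy batch of two 32x32 RGB images
with torch.no_grad():
    same_output = torch.allclose(model(x), loaded_model(x))
print(same_output)
```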