
Save and load a model

A manufacturing company wants to classify its products from images in order to determine the appropriate shipping packaging. You have trained a highly accurate model in PyTorch, and you now want to save its trained weights so you can reuse the model later and share it with your team, who should be able to load it without extra work.

torch and torch.nn (as nn) have already been imported. The pre-trained model object is available in your workspace as model, and its architecture is defined by the class ManufacturingCNN.

This exercise is part of the course

Deep Learning for Images with PyTorch

Exercise instructions

  • Save the pre-trained model's weights to ModelCNN.pth by saving its state dict, which holds the learned parameters, rather than the model object itself.
  • Create a new instance of the ManufacturingCNN class called loaded_model.
  • Load the weights stored in ModelCNN.pth into loaded_model by passing them to .load_state_dict().
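To see what "saving the weights" means in practice, here is a minimal sketch that inspects a state dict. It assumes a tiny stand-in network, TinyCNN, which is hypothetical and not the course's ManufacturingCNN. model.state_dict() returns an ordered dictionary that maps parameter names to tensors and stores no record of the class or its code.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in architecture; the real ManufacturingCNN is not shown in the exercise.
class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = x.mean(dim=(2, 3))  # global average pool over height and width
        return self.fc(x)

model = TinyCNN()
state = model.state_dict()

# Each entry maps a parameter name to a tensor; the architecture itself is not stored.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# conv.weight (8, 3, 3, 3)
# conv.bias (8,)
# fc.weight (4, 8)
# fc.bias (4,)
```

Because only these named tensors are saved, loading them requires first constructing an instance of the same class, so that every name and shape lines up.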

Hands-on interactive exercise

Try this exercise by completing the sample code below.

# Save the model
torch.____(model.____, ____)

# Create a new model
loaded_model = ____

# Load the saved model
loaded_model.____(torch.____('ModelCNN.pth'))
print(loaded_model)
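One way to complete the scaffold is shown below as a self-contained sketch. The exercise does not show how ManufacturingCNN is defined, so the class here is a hypothetical placeholder with an arbitrary layer layout that assumes 32x32 RGB inputs. Only the save/load calls (torch.save, state_dict, torch.load, load_state_dict) correspond to the exercise.

```python
import torch
import torch.nn as nn

# Hypothetical placeholder for the exercise's ManufacturingCNN; the real layers are not shown.
class ManufacturingCNN(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Assumes 32x32 inputs: 16 channels x 16 x 16 after one 2x2 max-pool
        self.classifier = nn.Linear(16 * 16 * 16, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, start_dim=1))

model = ManufacturingCNN()  # stands in for the pre-trained model in the workspace

# Save the model's weights (its state dict), not the model object
torch.save(model.state_dict(), 'ModelCNN.pth')

# Create a new model with the same architecture
loaded_model = ManufacturingCNN()

# Load the saved weights into the new model
loaded_model.load_state_dict(torch.load('ModelCNN.pth'))
loaded_model.eval()  # switch to inference mode before making predictions
print(loaded_model)
```

Saving the state dict instead of pickling the whole model keeps the file independent of where the class was defined. Teammates only need the ManufacturingCNN class definition to rebuild the model and load the weights.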