Guardar y cargar un modelo
Una empresa manufacturera quiere clasificar sus proyectos en función de imágenes y determinar el embalaje de envío adecuado. Después de haber entrenado un modelo de alta precisión en PyTorch, ahora planeas guardar el modelo y sus pesos preentrenados para su uso futuro y compartirlos con tu equipo, asegurándote de que puedan cargarlos sin problemas.
torch
y torch.nn
como nn
se han importado. El objeto del modelo preentrenado está disponible en tu área de trabajo como model
, y su arquitectura como ManufacturingCNN
.
Este ejercicio forma parte del curso
Aprendizaje profundo para imágenes con PyTorch
Instrucciones del ejercicio
- Guarda el modelo preentrenado como
ModelCNN.pth
y recuerda guardar los pesos, no solo la arquitectura. - Crea una instancia de modelo llamada «
loaded_model
» a partir de la clase «ManufacturingCNN()
». - Carga los pesos de
ModelCNN.pth
enloaded_model
pasando los pesos a.load_state_dict()
.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
# Save the model
torch.____(model.____, ____)
# Create a new model
loaded_model = ____
# Load the saved model
loaded_model.____(torch.____('ModelCNN.pth'))
print(loaded_model)