
Exporting models with TorchScript

1. Exporting models with TorchScript

Welcome! In this video, we'll learn how to export PyTorch models with TorchScript—a format designed for efficient deployment.

2. What is TorchScript?

Unlike standard PyTorch models, TorchScript models don't need Python, making them ideal for high-performance production systems. Common use cases include deploying on mobile devices

3. What is TorchScript?

or embedding in production environments where Python may not be available.

4. Converting models to TorchScript

There are two primary methods to convert PyTorch models into TorchScript: `torch.jit.trace` and `torch.jit.script`. The `trace` method works by recording the model's execution on example inputs and is ideal for simpler, feedforward models. However, tracing does not capture dynamic control flow such as loops and conditionals, which can lead to incorrect outputs in production. On the other hand, `script` compiles the entire model by analyzing its code, making it better suited for models that include control flow.

First, we import `torch` and `torch.nn` to define a simple model. This model has a `forward` method that multiplies the input by 2. Once the model is defined, we instantiate it as `SimpleModel()`. Finally, we use `torch.jit.script` to convert the model into TorchScript, preparing it for deployment. This scripted model can now be saved and used efficiently in production workflows.
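Here is a minimal sketch of the workflow just described, using the doubling model from the narration:

```python
import torch
import torch.nn as nn

# A simple feedforward model whose forward pass doubles its input
class SimpleModel(nn.Module):
    def forward(self, x):
        return x * 2

model = SimpleModel()

# Compile the model to TorchScript by analyzing its Python code
scripted_model = torch.jit.script(model)
```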

5. Saving and loading TorchScript models

Once a model is converted to TorchScript, the next step is to save it for deployment. The function `torch.jit.save` allows us to save the scripted model to a file. Later, we can use `torch.jit.load` to reload the model for inference. The example here shows how we save a TorchScript model to a file named `model.pt` and then reload it. The ability to save and load models makes TorchScript invaluable for production workflows.
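Continuing the sketch above, saving and reloading might look like this:

```python
import torch

# "scripted_model" is the TorchScript module from the previous sketch
torch.jit.save(scripted_model, "model.pt")

# Later, reload the model for inference; the original Python
# class definition is no longer required
loaded_model = torch.jit.load("model.pt")
```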

6. Performing inference with TorchScript

The final step in the TorchScript workflow is inference. After loading the model, we can pass input tensors to it for predictions. The outputs are identical to what we would expect from the original PyTorch model. For example, with an input tensor of `[1.0, 2.0, 3.0]`, the output tensor would be `[2.0, 4.0, 6.0]`. This consistency ensures that exporting a model does not compromise its accuracy or functionality.
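As a sketch, running the reloaded doubling model from the earlier example:

```python
import torch

loaded_model = torch.jit.load("model.pt")

# The scripted model doubles its input, just like the original
output = loaded_model(torch.tensor([1.0, 2.0, 3.0]))
print(output)  # tensor([2., 4., 6.])
```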

7. TorchScript in a nutshell

Let's recap the syntax we've covered. `torch.jit.trace` is perfect for static, feedforward models without control flow. Use it when your model has a straightforward execution path. On the other hand, `torch.jit.script` is more powerful, as it compiles Python code, including any control flow like loops or conditionals, making it ideal for complex models. When deploying, you'll need `torch.jit.save` to store your scripted model as a file. This file is portable and optimized for deployment. Later, you can use `torch.jit.load` to reload the model and perform inference. Together, these functions streamline the transition from development to production.
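For comparison, here is a minimal sketch of the tracing path with the same doubling model; unlike `torch.jit.script`, `torch.jit.trace` requires a concrete example input:

```python
import torch
import torch.nn as nn

# The same doubling model as before
class SimpleModel(nn.Module):
    def forward(self, x):
        return x * 2

# trace records the operations executed on the example input;
# input-dependent control flow would not be captured
traced_model = torch.jit.trace(SimpleModel(), torch.tensor([1.0, 2.0, 3.0]))
```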

8. Let's practice!

Now, let's move on to exercises where you'll practice your TorchScript skills and reinforce your understanding of converting models, saving them, and performing inference.