1. FastAPI prediction with a pre-trained model
Now that we've reviewed how to build GET and POST endpoints with FastAPI, let's learn how to load a pre-trained model and run the application.
2. Setting up the environment
To create a FastAPI app for serving ML predictions, we need three key libraries: FastAPI to build the API; uvicorn, an ASGI (Asynchronous Server Gateway Interface) server that runs the app and handles concurrent requests; and joblib to load pre-trained models.
Once we have these set up, we will create an instance of our FastAPI app, which serves as the foundation of our project.
3. Loading the pre-trained penguin classifier
As mentioned, we will use the joblib library to load the pre-trained model.
The model was trained on the Palmer Penguins dataset. It uses four input features (culmen_length, culmen_depth, flipper_length, and body_mass) to predict the penguin species: Adelie, Chinstrap, or Gentoo.
We'll load our pre-trained "penguin_classifier" model saved in a pickle file. A pickle file is a binary format used for saving machine learning models or complex data objects.
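The sketch below illustrates the save-and-load round trip that joblib performs. It assumes joblib is installed, uses a plain dictionary as a hypothetical stand-in for the trained classifier, and names the file penguin_classifier.pkl, which is an assumption. In the app itself, only the joblib.load() call on the real model file is needed.

```python
import os
import tempfile

import joblib

# Hypothetical stand-in for the trained classifier; any picklable object is saved the same way.
stand_in = {
    "features": ["culmen_length", "culmen_depth", "flipper_length", "body_mass"],
    "classes": ["Adelie", "Chinstrap", "Gentoo"],
}

# Write to a temporary directory; the file name is an assumption.
path = os.path.join(tempfile.mkdtemp(), "penguin_classifier.pkl")
joblib.dump(stand_in, path)  # serialize the object to a binary file

# This is the step the app performs at startup on the real model file.
model = joblib.load(path)  # rebuild the object from the file
```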
4. Uvicorn
Uvicorn is the ASGI server that runs our FastAPI apps. It is written in Python and serves Python web applications.
From the command line, we pass uvicorn the module and app names, and we can set the host and port with the --host and --port options.
If we want to start uvicorn from our own Python scripts instead, the uvicorn.run() function is also available.
5. Creating the prediction endpoint
Let's set up our prediction API endpoint.
We'll define a POST route using the decorator @app.post("/predict"). This decorator registers the predict function as the handler for POST requests sent to the /predict path.
The function takes the four penguin measurements as inputs and returns the prediction from the model's predict() method.
We use POST here because the client sends the measurements as data in the request body.
6. Running the application
Let's run our app!
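We'll use the if __name__ == "__main__": statement so that the server starts only when the script is run directly, not when it is imported. Inside it, we call uvicorn.run() to launch a uvicorn server on host 0.0.0.0 at port 8080; 0.0.0.0 means the server listens on all network interfaces, including localhost. We save the code in a file called your_api_script.py and start the app with the command python your_api_script.py.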
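To see what the guard does, here is a standard-library-only sketch. It writes a hypothetical stand-in for your_api_script.py whose guarded block records that it ran instead of calling uvicorn.run(), then loads the file twice: once by importing it, as the uvicorn command line does with your_api_script:app, and once by running it as a script. Only the second run executes the guarded block, so importing the module never starts a second server.

```python
import importlib.util
import pathlib
import runpy
import tempfile

# Hypothetical stand-in for your_api_script.py. The guarded block records that it ran
# instead of calling uvicorn.run(), so this demo never starts a real server.
SCRIPT = '''
started = []
if __name__ == "__main__":
    started.append("server would start here")  # real script: uvicorn.run(app, host="0.0.0.0", port=8080)
'''

path = pathlib.Path(tempfile.mkdtemp()) / "your_api_script.py"
path.write_text(SCRIPT)

# 1) Import the file as a module: __name__ is "your_api_script", so the guard is skipped.
spec = importlib.util.spec_from_file_location("your_api_script", str(path))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
imported_started = module.started

# 2) Run the file as a script (like python your_api_script.py): __name__ is "__main__".
script_globals = runpy.run_path(str(path), run_name="__main__")
run_started = script_globals["started"]
```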
7. Testing the API
To test the API, we will use cURL, a command-line tool commonly used for API testing.
We'll break down the cURL command: the capital -X POST flag sets the HTTP method. For http URLs, cURL uses port 80 by default, so we add our server's port to the URL, as in localhost:8080. The capital -H flag sets a request header; here, Content-Type: application/json tells the server that the body is JSON. The lowercase -d flag supplies the request body, which holds the JSON data.
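For comparison, the same request can be built in Python with the standard library's urllib.request. This sketch assumes the server from the previous steps is listening on localhost:8080 and uses hypothetical feature values; the method, headers, and data arguments play the roles of -X POST, -H, and -d. It only builds the request, and the commented lines show how it would be sent.

```python
import json
import urllib.request

# Hypothetical feature values, keyed by the four feature names from the text.
payload = {
    "culmen_length": 46.0,
    "culmen_depth": 14.5,
    "flipper_length": 215.0,
    "body_mass": 5000.0,
}

request = urllib.request.Request(
    "http://localhost:8080/predict",               # explicit port, as in the cURL URL
    data=json.dumps(payload).encode("utf-8"),      # request body, like -d
    headers={"Content-Type": "application/json"},  # header, like -H
    method="POST",                                 # HTTP method, like -X POST
)

# Sending it needs the server to be running:
# with urllib.request.urlopen(request) as response:
#     print(json.loads(response.read()))
```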
The API then returns a response with the predicted penguin species and a confidence score.
8. Let's practice!
Time to practice!