Serving a Bayesian Model as an API with FastAPI
You built a Bayesian forecasting model with PyMC. It samples beautifully in a notebook. The posterior distributions look great. Now what? At some point, that model needs to leave the notebook and start serving predictions to other systems. This post walks through one practical way to do that: wrapping a fitted PyMC model in a FastAPI application so that other services can request forecasts over HTTP.
Why FastAPI?
If you’ve served ML models before, you’ve probably seen Flask-based solutions. Flask works fine, but FastAPI gives us a few things that matter when serving probabilistic models:
- Automatic request validation via Pydantic, so we don’t have to write boilerplate to check inputs.
- Async support out of the box, which becomes relevant when sampling from the posterior can take a non-trivial amount of time.
- Auto-generated docs at /docs, which makes it easy for consumers of your API to understand what it expects and returns.
None of this is revolutionary, but it removes friction. And when you are trying to get a model into production, removing friction is the whole game.
The model
For this post, we’ll build a simple Bayesian time series model: a linear trend with seasonality. Think of it as a lightweight version of what you might use to forecast daily revenue, weekly active users, or any metric that has a trend and repeats on a cycle.
Let’s start with generating some synthetic data and fitting the model:
import numpy as np
import pymc as pm
import arviz as az
# Generate synthetic daily data with trend + weekly seasonality
rng = np.random.default_rng(42)
n_days = 365
t = np.arange(n_days, dtype=float)
trend = 0.05 * t
seasonality = 3 * np.sin(2 * np.pi * t / 7)
noise = rng.normal(0, 1, n_days)
y = 10 + trend + seasonality + noise
Now we define and fit the PyMC model:
with pm.Model() as forecast_model:
    # Priors
    intercept = pm.Normal("intercept", mu=0, sigma=10)
    slope = pm.Normal("slope", mu=0, sigma=1)
    amp = pm.HalfNormal("amplitude", sigma=5)
    period = pm.Normal("period", mu=7, sigma=0.5)
    sigma = pm.HalfNormal("sigma", sigma=2)

    # Deterministic trend + seasonality
    mu = intercept + slope * t + amp * pm.math.sin(2 * np.pi * t / period)

    # Likelihood
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

    # Sample
    trace = pm.sample(2000, tune=1000, cores=2, random_seed=42)
After sampling, we can verify that our model has converged by checking the trace summary:
az.summary(trace, var_names=["intercept", "slope", "amplitude", "period", "sigma"])
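If the fit runs in an automated pipeline, reading the summary table by eye may not be enough. The sketch below computes a simplified Gelman-Rubin R-hat with NumPy to show what the r_hat column in the summary is measuring. It is not ArviZ's implementation (ArviZ uses a more refined rank-normalized, split-chain variant), and the helper name simple_rhat and the synthetic chains are assumptions made for illustration.

```python
import numpy as np


def simple_rhat(chains: np.ndarray) -> float:
    """Classic (non-split, non-rank-normalized) Gelman-Rubin statistic.

    `chains` has shape (n_chains, n_draws) for a single scalar parameter.
    Hypothetical helper for illustration; ArviZ's r_hat is more refined.
    """
    n_chains, n_draws = chains.shape
    # Average of the per-chain sample variances
    within = chains.var(axis=1, ddof=1).mean()
    # Variance of the chain means, scaled by the number of draws
    between = n_draws * chains.mean(axis=1).var(ddof=1)
    # Pooled estimate of the posterior variance
    pooled = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(pooled / within))


rng = np.random.default_rng(0)
mixed = rng.normal(0.0, 1.0, size=(4, 1000))   # chains agree with each other
stuck = mixed + 5.0 * np.arange(4)[:, None]    # chains sit at different levels
rhat_mixed = simple_rhat(mixed)                # close to 1
rhat_stuck = simple_rhat(stuck)                # far above 1
```

With the fitted trace, the input for one parameter would be trace.posterior["slope"].values, which has shape (chain, draw). Values close to 1 suggest the chains agree; values well above 1 suggest they have not mixed.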
At this point, we have a fitted model and its posterior trace. The next challenge is making this available as an API.
Saving the model artifacts
Before we build the API, we need to persist the trace so the API server can load it without re-fitting. ArviZ’s InferenceData format makes saving the trace straightforward, and a small pickle file holds the training metadata:
import os
import pickle

# Make sure the output directory exists before writing to it
os.makedirs("model_artifacts", exist_ok=True)

# Save the trace
trace.to_netcdf("model_artifacts/trace.nc")

# Save the training data bounds (we need these for forecasting)
artifacts = {
    "n_train": n_days,
    "t_train": t,
    "y_train": y,
}
with open("model_artifacts/metadata.pkl", "wb") as f:
    pickle.dump(artifacts, f)
We save the trace as a NetCDF file (ArviZ’s native format) and the training metadata separately. The metadata includes information like how many training days we have, which we need to construct the time index for future predictions.
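Pickle is convenient, but loading a pickle file can execute arbitrary code, so it should only be used for files you trust. Since the forecaster in this post only reads n_train from the metadata, a plain JSON file would be enough. The following is a minimal sketch of that alternative, not what the rest of this post uses; the helper names save_metadata and load_metadata and the metadata.json file name are assumptions.

```python
import json
import os
import tempfile


def save_metadata(path: str, n_train: int) -> None:
    """Write the training metadata as JSON (hypothetical helper)."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"n_train": int(n_train)}, f)


def load_metadata(path: str) -> dict:
    """Read the metadata back; the result is a plain dict."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Round trip in a temporary directory so the example has no side effects
with tempfile.TemporaryDirectory() as tmp:
    meta_path = os.path.join(tmp, "model_artifacts", "metadata.json")
    save_metadata(meta_path, n_train=365)
    restored = load_metadata(meta_path)
```

The trade-off is that JSON cannot store NumPy arrays directly, so t_train and y_train would need to be converted to lists or saved separately if you want to keep them.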
Building the FastAPI application
Now for the core of the post. Our API needs to do three things:
- Load the fitted model artifacts on startup.
- Accept a request specifying how many days ahead to forecast.
- Return posterior predictive summaries (mean, credible intervals) for each future time step.
Let’s start with the project structure:
forecast_api/
    app.py
    model.py
    schemas.py
    model_artifacts/
        trace.nc
        metadata.pkl
Defining the request and response schemas
We use Pydantic models to define what the API expects and returns. This is one of FastAPI’s strengths: these schemas double as documentation and validation.
# schemas.py
from pydantic import BaseModel, Field


class ForecastRequest(BaseModel):
    horizon: int = Field(
        ..., gt=0, le=365, description="Number of days ahead to forecast"
    )
    credible_interval: float = Field(
        default=0.94, gt=0, lt=1, description="Width of the credible interval"
    )


class ForecastPoint(BaseModel):
    day: int
    mean: float
    lower: float
    upper: float


class ForecastResponse(BaseModel):
    horizon: int
    credible_interval: float
    forecasts: list[ForecastPoint]
Nothing fancy here. The request takes a horizon (how many days to forecast) and an optional credible_interval (defaulting to 94%, a common choice in Bayesian analysis). The response returns a list of forecast points, each with a mean and credible interval bounds.
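To see what the validation buys us, here is a small sketch that exercises the request schema on its own, outside FastAPI. It redefines a trimmed copy of ForecastRequest so the snippet is self-contained; the variable names ok and rejected are just for illustration.

```python
from pydantic import BaseModel, Field, ValidationError


class ForecastRequest(BaseModel):
    # Same constraints as in schemas.py, without the descriptions
    horizon: int = Field(..., gt=0, le=365)
    credible_interval: float = Field(default=0.94, gt=0, lt=1)


# A valid request picks up the default credible interval
ok = ForecastRequest(horizon=7)

# An out-of-range horizon is rejected before any model code runs
try:
    ForecastRequest(horizon=0)
    rejected = False
except ValidationError:
    rejected = True
```

When the same schema is used as the body of a FastAPI route, a failed validation is returned to the client as a 422 response, so the prediction code never sees an invalid horizon.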
The model wrapper
Next, we create a class that encapsulates the model loading and prediction logic. This keeps the FastAPI route handlers clean and makes the model logic testable independently.
# model.py
import pickle

import numpy as np
import arviz as az


class BayesianForecaster:
    def __init__(self, trace_path: str, metadata_path: str):
        self.trace = az.from_netcdf(trace_path)
        with open(metadata_path, "rb") as f:
            self.metadata = pickle.load(f)

    def predict(self, horizon: int, credible_interval: float = 0.94):
        posterior = self.trace.posterior

        # Extract posterior samples (flatten chains)
        intercept = posterior["intercept"].values.flatten()
        slope = posterior["slope"].values.flatten()
        amplitude = posterior["amplitude"].values.flatten()
        period = posterior["period"].values.flatten()
        sigma = posterior["sigma"].values.flatten()

        n_train = self.metadata["n_train"]
        future_t = np.arange(n_train, n_train + horizon, dtype=float)

        # Generate posterior predictive samples for each future time step
        rng = np.random.default_rng(0)
        n_samples = len(intercept)
        predictions = np.zeros((n_samples, horizon))
        for i, t_val in enumerate(future_t):
            mu = intercept + slope * t_val + amplitude * np.sin(2 * np.pi * t_val / period)
            predictions[:, i] = rng.normal(mu, sigma)

        # Summarize
        alpha = 1 - credible_interval
        results = []
        for i in range(horizon):
            samples = predictions[:, i]
            results.append({
                "day": int(future_t[i]),
                "mean": float(np.mean(samples)),
                "lower": float(np.percentile(samples, 100 * alpha / 2)),
                "upper": float(np.percentile(samples, 100 * (1 - alpha / 2))),
            })
        return results
A few things to note here:
- We flatten the chains from the posterior so we get a single array of samples per parameter. This is fine for generating predictions; we’ve already checked convergence during fitting.
- For each future time step, we compute mu using all posterior samples and then draw from the observation noise. This gives us full posterior predictive samples, not just the mean prediction.
- We summarize by computing the mean and the percentiles corresponding to the requested credible interval.
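The per-step loop is easy to read, but the same computation can be written with NumPy broadcasting. The sketch below is one possible vectorized variant, not the code used in the rest of this post; the function names posterior_predictive and summarize are assumptions, and the posterior draws are made-up numbers standing in for the real trace.

```python
import numpy as np


def posterior_predictive(intercept, slope, amplitude, period, sigma, future_t, rng):
    """Draw one predictive sample per posterior draw and time step.

    Parameter arrays have shape (n_samples,); future_t has shape (horizon,).
    Returns an array of shape (n_samples, horizon).
    """
    # Add axes so each posterior draw broadcasts against every time step
    t = future_t[None, :]
    mu = (
        intercept[:, None]
        + slope[:, None] * t
        + amplitude[:, None] * np.sin(2 * np.pi * t / period[:, None])
    )
    return rng.normal(mu, sigma[:, None])


def summarize(predictions, credible_interval=0.94):
    """Mean and equal-tailed interval bounds for each time step."""
    alpha = 1 - credible_interval
    lower, upper = np.percentile(
        predictions, [100 * alpha / 2, 100 * (1 - alpha / 2)], axis=0
    )
    return predictions.mean(axis=0), lower, upper


# Made-up posterior draws standing in for the real trace
rng = np.random.default_rng(0)
n_samples = 4000
draws = {
    "intercept": rng.normal(10, 0.1, n_samples),
    "slope": rng.normal(0.05, 0.001, n_samples),
    "amplitude": np.abs(rng.normal(3, 0.1, n_samples)),
    "period": rng.normal(7, 0.01, n_samples),
    "sigma": np.abs(rng.normal(1, 0.05, n_samples)),
}
future_t = np.arange(365, 372, dtype=float)
predictions = posterior_predictive(**draws, future_t=future_t, rng=rng)
mean, lower, upper = summarize(predictions)
```

Note that this version consumes random numbers in a different order than the loop, so for the same seed the individual draws will differ even though they come from the same distribution.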
The FastAPI app
Finally, we wire it all together:
# app.py
from contextlib import asynccontextmanager

from fastapi import FastAPI

from schemas import ForecastRequest, ForecastResponse, ForecastPoint
from model import BayesianForecaster

forecaster = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global forecaster
    forecaster = BayesianForecaster(
        trace_path="model_artifacts/trace.nc",
        metadata_path="model_artifacts/metadata.pkl",
    )
    yield


app = FastAPI(
    title="Bayesian Forecast API",
    description="Serves forecasts from a PyMC time series model",
    lifespan=lifespan,
)


@app.post("/forecast", response_model=ForecastResponse)
async def forecast(request: ForecastRequest):
    results = forecaster.predict(
        horizon=request.horizon,
        credible_interval=request.credible_interval,
    )
    return ForecastResponse(
        horizon=request.horizon,
        credible_interval=request.credible_interval,
        forecasts=[ForecastPoint(**r) for r in results],
    )


@app.get("/health")
async def health():
    return {"status": "ok", "model_loaded": forecaster is not None}
We use FastAPI’s lifespan context manager to load the model artifacts once at startup. This way, the model is loaded into memory when the server starts and stays there for the lifetime of the application. No re-loading on every request.
The /forecast endpoint accepts a POST request with the forecast parameters and returns the predictions. The /health endpoint is a simple check that the model is loaded, which is useful for load balancers and container orchestration.
Running the API
Start the server with:
uvicorn app:app --host 0.0.0.0 --port 8000
And test it:
curl -X POST http://localhost:8000/forecast \
-H "Content-Type: application/json" \
-d '{"horizon": 7, "credible_interval": 0.94}'
You should get back something like:
{
  "horizon": 7,
  "credible_interval": 0.94,
  "forecasts": [
    {"day": 365, "mean": 28.31, "lower": 26.12, "upper": 30.48},
    {"day": 366, "mean": 29.42, "lower": 27.25, "upper": 31.63},
    ...
  ]
}
Each forecast point includes the mean prediction and the credible interval bounds. This is one of the advantages of serving a Bayesian model: you don’t just get a point estimate, you get a full uncertainty range that downstream systems can use to make better decisions.
You can also visit http://localhost:8000/docs to see the auto-generated interactive documentation where you can test the endpoint directly from your browser.
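Other services will usually call the endpoint from code rather than curl. As a rough sketch using only the Python standard library, the request from the curl example could be built like this; the helper name build_forecast_request is an assumption, and the base URL matches the local server started above.

```python
import json
import urllib.request


def build_forecast_request(base_url: str, horizon: int, credible_interval: float = 0.94):
    """Build (but do not send) a POST request for the /forecast endpoint."""
    payload = {"horizon": horizon, "credible_interval": credible_interval}
    return urllib.request.Request(
        url=f"{base_url}/forecast",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_forecast_request("http://localhost:8000", horizon=7)

# Sending requires the server to be running:
# with urllib.request.urlopen(request) as response:
#     forecasts = json.load(response)["forecasts"]
```

Keeping the request construction separate from the network call makes it easy to unit test the client without a live server.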
Considerations for production
This setup works well for getting a Bayesian model off your laptop and into a service. But there are a few things worth thinking about before calling it production-ready:
- Model updates: Right now, the model artifacts are loaded from disk at startup. If you retrain the model, you need to restart the server. A more robust approach is to load artifacts from an object store (S3, GCS) and implement a reload mechanism, either on a schedule or triggered by a webhook.
- Response times: Generating posterior predictive samples is not instant, especially with large traces or long horizons. If latency matters, consider pre-computing predictions for common horizons or caching results.
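- Concurrency: The predict method is CPU-bound, and because the /forecast handler is declared with async def, it runs on the event loop and blocks other requests in the same worker until it finishes. NumPy releasing the GIL does not change this. Declaring the handler with a plain def lets FastAPI run it in its threadpool, and running several uvicorn workers adds process-level parallelism. For heavier workloads, you might want to offload sampling to a task queue like Celery.
- Monitoring: Log prediction latencies and input distributions. Bayesian models can degrade gracefully (wider credible intervals as the model becomes less confident), but you still want to know when that’s happening.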
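To make the caching idea from the response-times point concrete, here is a minimal sketch built on functools.lru_cache. It assumes predict is deterministic for a given horizon and credible interval, which holds for the forecaster in this post because it seeds its random generator with a fixed value. The CachedForecaster wrapper and the CountingForecaster stand-in are hypothetical names, not part of the application above.

```python
from functools import lru_cache


class CachedForecaster:
    """Wrap any object with a predict(horizon, credible_interval) method.

    Hypothetical wrapper for illustration; it assumes predict is
    deterministic for a given pair of arguments.
    """

    def __init__(self, forecaster, maxsize: int = 128):
        self._forecaster = forecaster
        # Build the cache per instance so a reloaded model gets a fresh one
        self._cached = lru_cache(maxsize=maxsize)(self._predict)

    def _predict(self, horizon: int, credible_interval: float):
        # Store a tuple so the cached sequence itself cannot be modified
        # (the dicts inside it are still shared between callers)
        return tuple(self._forecaster.predict(horizon, credible_interval))

    def predict(self, horizon: int, credible_interval: float = 0.94):
        return list(self._cached(horizon, credible_interval))


class CountingForecaster:
    """Stand-in model that records how often predict is called."""

    def __init__(self):
        self.calls = 0

    def predict(self, horizon, credible_interval):
        self.calls += 1
        return [{"day": d, "mean": 0.0, "lower": -1.0, "upper": 1.0} for d in range(horizon)]


backend = CountingForecaster()
cached = CachedForecaster(backend)
first = cached.predict(7, 0.94)
second = cached.predict(7, 0.94)  # served from the cache
```

Because the cache lives on the wrapper instance, replacing the wrapper after a retrain discards stale forecasts automatically.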
Conclusion
Serving a Bayesian model doesn’t need to be complicated. The core idea is simple: fit the model offline, save the trace, load it at API startup, and use the posterior samples to generate predictions on demand. FastAPI handles the HTTP plumbing, Pydantic validates the inputs, and PyMC gives us the full posterior predictive distribution. The result is an API that returns not just forecasts but honest uncertainty estimates, which is arguably the whole point of going Bayesian in the first place.