Serving Predictions Without Heavy Infrastructure
“Model serving” means taking a trained model artifact and exposing it in a repeatable way so other code (or other systems) can request predictions. In beginner-friendly setups, you can start with (1) a pure Python script that runs inference in batches and writes results to disk, then (2) a minimal HTTP service that accepts requests and returns predictions. In both cases, the most important idea is to define a stable input/output contract: which fields you accept, their dtypes and shapes, and how they map to the model’s expected tensors.
1) Pure Python Inference Script (Batch Predictions)
This pattern is useful for offline scoring: nightly jobs, backfills, or scoring a CSV file. It avoids networking and keeps debugging simple.
Step-by-step: Load a SavedModel and run batch inference
- Read input rows (CSV/JSONL/Parquet) and convert them into numeric arrays.
- Validate and preprocess inputs to match the model’s expected dtype and shape.
- Run inference in batches to balance speed and memory.
- Postprocess outputs (e.g., argmax, thresholding, formatting) and write results.
Example: Batch scoring script
import argparse, csv, json, logging, sys, time
from typing import List, Dict
import numpy as np
import tensorflow as tf
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s"
)
def read_csv_rows(path: str) -> List[Dict[str, str]]:
with open(path, newline="") as f:
reader = csv.DictReader(f)
return list(reader)
def validate_and_preprocess(rows: List[Dict[str, str]]):
"""Convert raw strings to model-ready tensors.
Example assumes a model that expects a dense float32 tensor of shape [batch, 3]
with features: [age, income, score].
"""
feats = []
for i, r in enumerate(rows):
try:
age = float(r["age"])
income = float(r["income"])
score = float(r["score"])
except (KeyError, ValueError) as e:
raise ValueError(f"Bad row at index {i}: {r}") from e
feats.append([age, income, score])
x = np.asarray(feats, dtype=np.float32) # shape: [batch, 3]
return tf.convert_to_tensor(x)
def batch_iter(items, batch_size: int):
for i in range(0, len(items), batch_size):
yield items[i:i + batch_size]
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--model_dir", required=True)
ap.add_argument("--input_csv", required=True)
ap.add_argument("--output_jsonl", required=True)
ap.add_argument("--batch_size", type=int, default=256)
args = ap.parse_args()
# Load SavedModel
model = tf.saved_model.load(args.model_dir)
# Most SavedModels expose a callable signature. Commonly: "serving_default".
infer = model.signatures.get("serving_default")
if infer is None:
raise RuntimeError("No 'serving_default' signature found in SavedModel")
# Inspect inputs/outputs to ensure your preprocessing matches.
logging.info("Signature inputs: %s", infer.structured_input_signature)
logging.info("Signature outputs: %s", infer.structured_outputs)
rows = read_csv_rows(args.input_csv)
t0 = time.time()
n = 0
with open(args.output_jsonl, "w") as out:
for chunk in batch_iter(rows, args.batch_size):
x = validate_and_preprocess(chunk)
# Map to the signature's expected input name.
# If the signature expects a single tensor, it often has a key like "inputs".
# Adjust "inputs" to match what you see in structured_input_signature.
preds = infer(inputs=x)
# Many models return a dict of outputs; pick the relevant key.
# Adjust "outputs" to match your model.
y = preds.get("outputs")
if y is None:
# Fallback: if there's only one output, take it.
y = list(preds.values())[0]
y_np = y.numpy()
for r, p in zip(chunk, y_np):
record = {"input": r, "prediction": p.tolist()}
out.write(json.dumps(record) + "\n")
n += 1
dt = time.time() - t0
logging.info("Scored %d rows in %.3fs (%.1f rows/s)", n, dt, n / max(dt, 1e-9))
if __name__ == "__main__":
try:
main()
except Exception:
logging.exception("Inference job failed")
        sys.exit(1)
Practical notes for this pattern
Signature-driven preprocessing: Use infer.structured_input_signature to discover input names, dtypes, and shapes. Your script should adapt to those, not guess (see the sketch after this list).
Batching: Larger batches typically improve throughput (predictions per second) but can increase latency per item and memory usage.
Determinism: For offline jobs, you usually want stable outputs. Avoid random augmentations and ensure preprocessing is consistent.
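As a sketch of the first note, the snippet below reads the input name and TensorSpec from the signature instead of hardcoding a key. It assumes a signature with a single dense input, like the three-feature model in the batch script above; the model path is a placeholder.
import tensorflow as tf
model = tf.saved_model.load("/path/to/saved_model")  # placeholder path
infer = model.signatures["serving_default"]
# structured_input_signature is (args, kwargs); for a serving signature the
# kwargs dict maps input names to TensorSpecs.
input_specs = infer.structured_input_signature[1]
input_name = next(iter(input_specs))   # e.g. "inputs" (assumes a single input)
spec = input_specs[input_name]         # carries dtype and shape, e.g. float32, [None, 3]
# Build a dummy batch matching the discovered dtype (shape assumed [None, 3] as in
# the example model) and call the signature by the discovered input name.
x = tf.zeros([4, 3], dtype=spec.dtype)
preds = infer(**{input_name: x})
print({k: v.shape for k, v in preds.items()})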
Latency vs. Throughput: Choosing Batch Size
Batching is the main lever you control in both scripts and services:
Latency is how long a single request takes end-to-end. Small batches (even size 1) minimize waiting time but may underutilize the CPU/GPU.
Throughput is how many predictions you can produce per second. Larger batches often increase throughput because the model runs fewer times and uses vectorized compute.
Rule of thumb
Offline scoring: prefer larger batches (e.g., 256–4096) as long as memory is safe.
Online HTTP: start small (e.g., 1–32) and only add batching if you can tolerate extra waiting time or implement a short batching window.
Simple micro-benchmark snippet
import time
import numpy as np
import tensorflow as tf
model = tf.saved_model.load("/path/to/saved_model")
infer = model.signatures["serving_default"]
for bs in [1, 8, 32, 128, 512]:
x = tf.constant(np.random.rand(bs, 3).astype("float32"))
# Warm-up
for _ in range(10):
_ = infer(inputs=x)
t0 = time.time()
iters = 200
for _ in range(iters):
_ = infer(inputs=x)
dt = time.time() - t0
preds_per_sec = (bs * iters) / max(dt, 1e-9)
avg_ms_per_batch = (dt / iters) * 1000
print(f"batch={bs:4d} avg_ms/batch={avg_ms_per_batch:7.2f} preds/s={preds_per_sec:10.1f}")Interpretation: if avg_ms/batch grows slowly while preds/s grows a lot, batching helps. For HTTP, remember that user-perceived latency includes request parsing, validation, and network overhead.
2) Minimal HTTP Service Pattern
An HTTP service wraps inference behind an endpoint like POST /predict. The core flow is the same as the batch script, but you add request parsing, validation, and safe error handling.
Service responsibilities
Request validation: ensure required fields exist, types are correct, and sizes are bounded.
Preprocessing: convert JSON to tensors with correct dtype/shape.
Model call: call the SavedModel signature.
Postprocessing: convert tensors to JSON-friendly types, apply thresholds/labels.
Observability: structured logs, timing, and safe error messages.
Example: Minimal Flask service
from flask import Flask, request, jsonify
import logging, time
import numpy as np
import tensorflow as tf
app = Flask(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
MODEL_DIR = "/path/to/saved_model"
model = tf.saved_model.load(MODEL_DIR)
infer = model.signatures.get("serving_default")
if infer is None:
raise RuntimeError("No 'serving_default' signature found")
# Adjust these to your contract
MAX_BATCH = 256
N_FEATURES = 3
def parse_and_validate_json(payload):
"""Expected payload:
{
"instances": [
{"age": 30, "income": 50000, "score": 0.2},
...
]
}
"""
if not isinstance(payload, dict):
return None, ("Body must be a JSON object", 400)
instances = payload.get("instances")
if not isinstance(instances, list) or len(instances) == 0:
return None, ("'instances' must be a non-empty list", 400)
if len(instances) > MAX_BATCH:
return None, (f"Batch too large (max {MAX_BATCH})", 413)
feats = []
for i, inst in enumerate(instances):
if not isinstance(inst, dict):
return None, (f"Instance {i} must be an object", 400)
try:
age = float(inst["age"])
income = float(inst["income"])
score = float(inst["score"])
except (KeyError, ValueError, TypeError):
return None, (f"Instance {i} must include numeric age, income, score", 400)
feats.append([age, income, score])
x = np.asarray(feats, dtype=np.float32)
if x.ndim != 2 or x.shape[1] != N_FEATURES:
return None, (f"Expected shape [batch, {N_FEATURES}]", 400)
return tf.convert_to_tensor(x), None
def postprocess(pred_dict):
# Pick output tensor
y = pred_dict.get("outputs")
if y is None:
y = list(pred_dict.values())[0]
y_np = y.numpy()
# Example postprocessing: if output is probabilities, return both raw and argmax
if y_np.ndim == 2 and y_np.shape[1] > 1:
labels = np.argmax(y_np, axis=1)
return {"predictions": y_np.tolist(), "classes": labels.tolist()}
return {"predictions": y_np.tolist()}
@app.post("/predict")
def predict():
t0 = time.time()
payload = request.get_json(silent=True)
if payload is None:
return jsonify({"error": "Invalid or missing JSON"}), 400
x, err = parse_and_validate_json(payload)
if err:
msg, code = err
return jsonify({"error": msg}), code
try:
# Adjust input key name to match your signature
pred = infer(inputs=x)
result = postprocess(pred)
ms = (time.time() - t0) * 1000
logging.info("predict ok batch=%d ms=%.2f", int(x.shape[0]), ms)
return jsonify(result)
except Exception:
# Log full stack trace internally, return generic message to clients
logging.exception("predict failed")
return jsonify({"error": "Prediction failed"}), 500Testing the endpoint
curl -X POST http://localhost:5000/predict \
-H 'Content-Type: application/json' \
-d '{"instances":[{"age":30,"income":50000,"score":0.2},{"age":45,"income":120000,"score":0.9}]}'Common improvements (still minimal)
Load model once: load at process start, not per request.
Bound request size: cap batch size and consider a maximum JSON body size to prevent memory spikes (see the sketch after this list).
Timeouts: enforce server timeouts so slow requests do not pile up.
Thread/process model: for CPU inference, multiple worker processes can increase throughput; for GPU, fewer workers may be better to avoid contention.
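A minimal sketch of the body-size and worker/timeout points, assuming the Flask app above is saved as app.py. MAX_CONTENT_LENGTH is standard Flask configuration; the gunicorn command in the comment uses its standard flags.
# In app.py: reject oversized bodies with 413 before any parsing or array allocation.
app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024  # 1 MiB cap; tune to your contract
# Run under a process manager with a hard request timeout, for example:
#   gunicorn --workers 4 --timeout 30 --bind 0.0.0.0:5000 app:app
# For GPU inference, start with --workers 1 to avoid device contention.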
3) Input/Output Contracts: Schemas, Dtypes, Shapes
A serving contract is an agreement between clients and your model service about how inputs are represented and what outputs mean. Most serving bugs are contract mismatches: wrong field names, wrong dtype, wrong shape, or inconsistent preprocessing.
What to specify in a contract
Schema: required fields, optional fields, defaults, allowed ranges.
Dtypes: float32 vs float64, int32 vs int64, strings vs numbers.
Shapes: per-instance shape and batch shape (e.g., [batch, 3]).
Semantics: what each feature means and its units (e.g., income in USD, normalized score in [0, 1]).
Output meaning: probabilities vs logits, class indices, thresholds.
Example contract: Dense numeric features
Request JSON
{
"instances": [
{"age": 30, "income": 50000, "score": 0.2},
{"age": 45, "income": 120000, "score": 0.9}
]
}
Model input tensor
# dtype: float32
# shape: [batch, 3]
[[30.0, 50000.0, 0.2],
 [45.0, 120000.0, 0.9]]
Response JSON (example)
{
"predictions": [[0.1, 0.9], [0.8, 0.2]],
"classes": [1, 0]
}
Example contract: Multiple named inputs
Some models expect a dictionary of tensors (for example, one tensor per feature). In that case, your preprocessing should build a dict keyed by the signature’s input names.
# Example request
{
"instances": [
{"age": 30, "country": "US"},
{"age": 45, "country": "DE"}
]
}
# Example model call (names must match the SavedModel signature)
inputs = {
"age": tf.constant([30.0, 45.0], dtype=tf.float32),
"country": tf.constant(["US", "DE"], dtype=tf.string)
}
outputs = infer(**inputs)
How to discover the contract from the SavedModel
Even if you already know your features, always confirm what the exported signature expects.
import tensorflow as tf
m = tf.saved_model.load("/path/to/saved_model")
fn = m.signatures["serving_default"]
print("Inputs:")
for name, spec in fn.structured_input_signature[1].items():
print(name, spec)
print("Outputs:")
for name, spec in fn.structured_outputs.items():
    print(name, spec)
Use this output to align your HTTP JSON keys, preprocessing, and the keyword arguments you pass to infer.
Safe Error Handling and Logging
Serving code runs continuously and receives unpredictable inputs. Your goal is to fail safely: return clear client errors for bad requests, avoid leaking sensitive details, and keep enough logs to debug issues.
Error handling guidelines
Differentiate 4xx vs 5xx: validation failures should be 400/413; unexpected exceptions should be 500.
Don’t leak internals: return generic 500 messages; log stack traces server-side.
Validate early: check required fields, types, and bounds before allocating large arrays.
Guard against NaNs/Infs: reject or sanitize non-finite numeric inputs if your model can’t handle them (see the sketch after this list).
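For that last point, a small helper like this sketch (the reject_non_finite name is illustrative) can run on the validated batch before the model call:
import numpy as np
def reject_non_finite(x_np: np.ndarray):
    """Return an (error message, HTTP status) tuple if the batch contains NaN or Inf,
    otherwise None. Call it after shape/dtype validation, before infer()."""
    if not np.isfinite(x_np).all():
        return ("Inputs must be finite numbers (no NaN/Inf)", 400)
    return None
In the Flask handler above, you would call it on x.numpy() and, if it returns a tuple, send it back as a JSON error, mirroring parse_and_validate_json.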
Logging guidelines
Log timings and batch sizes: request duration and batch size are key for performance debugging.
Use request IDs: attach a correlation ID (from a header or generated) to connect client reports to server logs.
Avoid logging raw payloads: inputs may contain personal or sensitive data; log summaries (counts, shapes, min/max) instead.
Example: Logging a safe summary of inputs
import numpy as np
def safe_input_summary(x_np: np.ndarray):
return {
"shape": list(x_np.shape),
"dtype": str(x_np.dtype),
"min": float(np.min(x_np)),
"max": float(np.max(x_np)),
"finite": bool(np.isfinite(x_np).all()),
}
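Putting the request-ID and summary guidelines together, here is a sketch of how the handler above could log both. The X-Request-ID header name and the log_request_summary helper are illustrative, not part of the earlier code, and the function assumes it runs inside a Flask request context.
import logging
import uuid
from flask import request
def log_request_summary(x_np):
    """Log a correlation ID plus a shape/dtype summary instead of the raw payload."""
    request_id = request.headers.get("X-Request-ID") or uuid.uuid4().hex
    logging.info("predict request_id=%s input=%s", request_id, safe_input_summary(x_np))
    return request_id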