home / skills / doanchienthangdev / omgkit / ml-serving-optimization
This skill helps you optimize ML serving by applying batching, caching, model compilation, and latency-reduction techniques for production systems.
`npx playbooks add skill doanchienthangdev/omgkit --skill ml-serving-optimization`

Review the files below or copy the command above to add this skill to your agents.
---
name: ml-serving-optimization
description: ML serving optimization techniques including batching, caching, model compilation, and latency reduction for production ML systems.
---
# ML Serving Optimization
Optimizing ML model inference for production.
## Inference Pipeline
```
┌─────────────────────────────────────────────────────────────┐
│                INFERENCE OPTIMIZATION STACK                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  REQUEST       PREPROCESSING   INFERENCE      POSTPROCESS   │
│  ────────      ─────────────   ─────────      ───────────   │
│  Batching      Vectorization   Compiled       Response      │
│  Queuing       Caching         Quantized      Caching       │
│  Load balance  Async I/O       Parallelized   Streaming     │
│                                                             │
│  LATENCY BREAKDOWN (typical):                               │
│  ├── Network:         1-5ms                                 │
│  ├── Preprocessing:   2-10ms                                │
│  ├── Model inference: 5-100ms                               │
│  └── Postprocessing:  1-5ms                                 │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```
## Dynamic Batching
```python
import asyncio
from collections import deque
import time
class DynamicBatcher:
    """Group concurrent single-item requests into batched model calls.

    Requests are queued by :meth:`predict`. A batch runs either as soon as
    ``max_batch_size`` requests are waiting or on the next tick of
    :meth:`batch_loop`, which has to be running as a background task for
    partially filled batches to be flushed.

    Args:
        model: Callable applied to the stacked batch tensor; indexing its
            result with ``[i]`` must yield the output for request ``i``.
        max_batch_size: Maximum number of requests per model call.
        max_wait_ms: Flush interval of :meth:`batch_loop`, in milliseconds.
    """

    def __init__(self, model, max_batch_size=32, max_wait_ms=10):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = deque()
        self.lock = asyncio.Lock()

    async def predict(self, input_data):
        """Queue one input and return its output once its batch has run.

        Raises:
            asyncio.TimeoutError: No batch produced a result within
                ``max_wait_ms`` plus one second.
            Exception: Whatever the model raised for this request's batch.
        """
        future = asyncio.get_running_loop().create_future()
        async with self.lock:
            self.queue.append((input_data, future))
            if len(self.queue) >= self.max_batch_size:
                await self._process_batch()
        # wait_for cancels `future` on timeout; _process_batch skips such
        # entries instead of failing on them.
        return await asyncio.wait_for(
            future, timeout=self.max_wait_ms / 1000 + 1
        )

    async def _process_batch(self):
        """Run the model on up to ``max_batch_size`` queued requests.

        Must be called with ``self.lock`` held. The model call itself is
        synchronous and therefore blocks the event loop while it runs.
        """
        batch_items = []
        while self.queue and len(batch_items) < self.max_batch_size:
            item = self.queue.popleft()
            # Drop requests whose caller already timed out or was cancelled.
            if not item[1].done():
                batch_items.append(item)
        if not batch_items:
            return

        try:
            inputs = torch.stack([data for data, _ in batch_items])
            with torch.no_grad():
                outputs = self.model(inputs)
        except Exception as exc:
            # Deliver the failure to every waiting caller; otherwise they
            # would only see an unrelated timeout.
            for _, future in batch_items:
                if not future.done():
                    future.set_exception(exc)
            return

        for i, (_, future) in enumerate(batch_items):
            if not future.done():
                future.set_result(outputs[i])

    async def batch_loop(self):
        """Flush the queue every ``max_wait_ms`` milliseconds, forever."""
        while True:
            await asyncio.sleep(self.max_wait_ms / 1000)
            async with self.lock:
                if self.queue:
                    await self._process_batch()
```
## Model Compilation
### TorchScript
```python
import torch
# Tracing
# Switch to eval mode first: the module is traced with its inference
# behaviour, and freezing (applied by optimize_for_inference to modules that
# are not frozen yet) only accepts eval-mode modules.
model.eval()
traced_model = torch.jit.trace(model, example_input)
traced_model.save("model_traced.pt")


# Scripting (for control flow)
class _ThresholdGate(torch.nn.Module):
    """Return the wrapped model's output, or zeros if its max is too low.

    A scripted free function cannot call a module captured from the enclosing
    Python scope, so the model is held as a submodule instead.
    """

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x, threshold: float = 0.5):
        output = self.inner(x)
        if output.max() > threshold:
            return output
        return torch.zeros_like(output)


# Callable as forward_with_control(x) or forward_with_control(x, threshold).
# The traced module is reused as the submodule, so the wrapped model does not
# itself have to be scriptable.
forward_with_control = torch.jit.script(_ThresholdGate(traced_model))

# Optimize for inference
traced_model = torch.jit.optimize_for_inference(traced_model)
```
### torch.compile (PyTorch 2.0+)
```python
import torch
# The three calls below are alternative configurations: each one rebinds
# `compiled_model`, so only the last assignment is in effect if the snippet
# is run top to bottom. `model` is expected to be defined by the caller.

# Default compilation
compiled_model = torch.compile(model)
# With options
compiled_model = torch.compile(
    model,
    mode="reduce-overhead", # or "max-autotune"
    fullgraph=True,
    dynamic=False
)
# Inference-optimized
# NOTE(review): "inductor" is presumably already the default backend, making
# this equivalent to passing only `mode` — confirm against the PyTorch version
# in use.
compiled_model = torch.compile(
    model,
    mode="reduce-overhead",
    backend="inductor"
)
```
### TensorRT
```python
import torch_tensorrt
# Compile to TensorRT
# NOTE(review): `model` presumably has to be in eval mode and on a CUDA
# device before compilation — confirm against the deployment setup.
trt_model = torch_tensorrt.compile(
    model,
    # Request the TorchScript frontend explicitly: the result is saved and
    # reloaded with torch.jit below, which only works for TorchScript output.
    ir="ts",
    inputs=[torch_tensorrt.Input(
        # Dynamic batch dimension: engines are built for 1..32, tuned for 8.
        min_shape=[1, 3, 224, 224],
        opt_shape=[8, 3, 224, 224],
        max_shape=[32, 3, 224, 224],
        dtype=torch.float16
    )],
    enabled_precisions={torch.float16},
    workspace_size=1 << 30  # 1 GiB
)
# Save and load
torch.jit.save(trt_model, "model_trt.ts")
loaded = torch.jit.load("model_trt.ts")
```
### ONNX Runtime
```python
import onnxruntime as ort
# Export to ONNX
torch.onnx.export(
    model, example_input, "model.onnx",
    input_names=['input'],
    output_names=['output'],
    # Mark the batch axis as dynamic on the output as well as the input;
    # otherwise the exported graph advertises a fixed output batch size.
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    }
)
# Create optimized session
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.intra_op_num_threads = 4
# GPU execution (providers are listed in priority order, CPU as fallback)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
session = ort.InferenceSession("model.onnx", session_options, providers=providers)
# Inference
# detach().cpu() makes the conversion safe for tensors that track gradients
# or live on a GPU; a plain .numpy() raises for both.
outputs = session.run(None, {'input': input_data.detach().cpu().numpy()})
```
## Caching Strategies
```python
from functools import lru_cache
import hashlib
import redis
import pickle
class InferenceCache:
    """Exact-match cache of model outputs stored in Redis.

    Args:
        redis_client: Object exposing ``get(key)`` and
            ``setex(key, ttl, value)``.
        ttl: Lifetime of each entry in seconds.
    """

    def __init__(self, redis_client, ttl=3600):
        self.redis = redis_client
        self.ttl = ttl

    def _hash_input(self, input_data):
        """Return a SHA-256 hex key covering dtype, shape and contents.

        ``input_data`` must provide ``tobytes()``. The raw bytes alone are
        ambiguous: arrays that differ only in shape or dtype can serialise to
        the same byte string, so both are mixed into the digest when present.
        """
        digest = hashlib.sha256()
        layout = (
            str(getattr(input_data, "dtype", "")),
            tuple(getattr(input_data, "shape", ())),
        )
        digest.update(repr(layout).encode("utf-8"))
        digest.update(input_data.tobytes())
        return digest.hexdigest()

    def get(self, input_data):
        """Return the cached output for ``input_data``, or ``None`` on a miss."""
        cached = self.redis.get(self._hash_input(input_data))
        if cached is None:
            return None
        # SECURITY: pickle.loads executes arbitrary code from its input. Only
        # use this cache with a Redis instance that untrusted parties cannot
        # write to.
        return pickle.loads(cached)

    def set(self, input_data, output):
        """Store ``output`` for ``input_data`` with the configured TTL."""
        key = self._hash_input(input_data)
        self.redis.setex(key, self.ttl, pickle.dumps(output))
# Embedding cache for similar inputs
class EmbeddingCache:
    """In-memory cache that returns the output of the most similar embedding.

    Args:
        threshold: Minimum cosine similarity for a stored entry to count as
            a hit.
        max_size: Optional cap on the number of entries. When exceeded, the
            oldest entry is evicted. ``None`` (the default) keeps every entry.

    Raises:
        ValueError: If ``max_size`` is given and smaller than 1.
    """

    def __init__(self, threshold=0.95, max_size=None):
        if max_size is not None and max_size < 1:
            raise ValueError("max_size must be at least 1 or None")
        self.embeddings = []
        self.outputs = []
        self.threshold = threshold
        self.max_size = max_size

    def find_similar(self, embedding):
        """Return the output stored for the closest embedding, or ``None``.

        Every stored embedding is compared on each call, so lookup cost grows
        linearly with the number of entries.
        """
        if not self.embeddings:
            return None
        similarities = torch.cosine_similarity(
            embedding.unsqueeze(0),
            torch.stack(self.embeddings)
        )
        max_sim, idx = similarities.max(0)
        if max_sim.item() > self.threshold:
            return self.outputs[int(idx)]
        return None

    def add(self, embedding, output):
        """Store ``output`` under ``embedding``, evicting the oldest if full."""
        # Detach so the cache does not keep an autograd graph alive.
        self.embeddings.append(embedding.detach())
        self.outputs.append(output)
        if self.max_size is not None and len(self.embeddings) > self.max_size:
            del self.embeddings[0]
            del self.outputs[0]
```
## Async Inference
```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
class AsyncInferenceService:
    """Run blocking model calls on a thread pool from asyncio code.

    The service owns its executor. Call :meth:`close` (or use the instance
    as a context manager) to release the worker threads.

    Args:
        model: Callable invoked with one input per request.
        num_workers: Number of worker threads.
    """

    def __init__(self, model, num_workers=4):
        self.model = model
        self.executor = ThreadPoolExecutor(max_workers=num_workers)

    def _sync_predict(self, input_data):
        """Call the model with gradient tracking disabled (worker thread)."""
        with torch.no_grad():
            return self.model(input_data)

    async def predict(self, input_data):
        """Return the model output for one input without blocking the loop."""
        # get_running_loop is the supported way to obtain the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            self._sync_predict,
            input_data
        )

    async def predict_batch(self, inputs):
        """Run :meth:`predict` for every input concurrently.

        Returns:
            A list of outputs in the same order as ``inputs``.
        """
        return await asyncio.gather(*(self.predict(inp) for inp in inputs))

    def close(self, wait=True):
        """Shut down the thread pool; pending work finishes if ``wait``."""
        self.executor.shutdown(wait=wait)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
# CUDA streams for parallel inference
class StreamedInference:
    """Spread independent model calls over several CUDA streams.

    When CUDA is unavailable (or ``num_streams`` is 0) the inputs are simply
    processed one after another, so the class stays usable on CPU-only hosts.

    Args:
        model: Callable invoked once per input.
        num_streams: Number of CUDA streams to cycle through.
    """

    def __init__(self, model, num_streams=4):
        self.model = model
        if torch.cuda.is_available():
            self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
        else:
            self.streams = []

    def predict_parallel(self, inputs):
        """Return the model output for each input, in input order."""
        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            if not self.streams:
                return [self.model(inp) for inp in inputs]

            # Work that produced the inputs was queued on the current stream;
            # the side streams must not start before it has finished.
            producer = torch.cuda.current_stream()
            for stream in self.streams:
                stream.wait_stream(producer)

            outputs = []
            for i, inp in enumerate(inputs):
                stream = self.streams[i % len(self.streams)]
                with torch.cuda.stream(stream):
                    outputs.append(self.model(inp))

        # Block until every stream is done so callers can use the outputs.
        torch.cuda.synchronize()
        return outputs
```
## Latency Profiling
```python
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List
@dataclass
class LatencyStats:
    """Summary of the recorded latencies for one stage, in milliseconds."""

    mean: float
    p50: float
    p95: float
    p99: float
    max: float


class LatencyProfiler:
    """Collect wall-clock timings per named stage and summarise them."""

    def __init__(self):
        # Stage name -> elapsed times in milliseconds, in recording order.
        self.timings: Dict[str, List[float]] = {}

    @contextmanager
    def measure(self, name: str):
        """Time the enclosed block and record it under ``name``.

        The timing is recorded even when the block raises, so failing
        requests still show up in the latency profile.
        """
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed = (time.perf_counter() - start) * 1000  # ms
            self.timings.setdefault(name, []).append(elapsed)

    @staticmethod
    def _percentile(sorted_times: List[float], pct: int) -> float:
        """Return the nearest-rank ``pct``-th percentile of a sorted list."""
        # Integer ceil(n * pct / 100); avoids float rounding surprises.
        rank = -(-len(sorted_times) * pct // 100)
        return sorted_times[max(rank, 1) - 1]

    def stats(self, name: str) -> LatencyStats:
        """Return summary statistics for the stage ``name``.

        Percentiles use the nearest-rank definition: the smallest sample such
        that at least ``pct`` percent of the samples are less than or equal
        to it.

        Raises:
            KeyError: If nothing was recorded under ``name``.
            ValueError: If the timing list for ``name`` is empty.
        """
        times = sorted(self.timings[name])
        if not times:
            raise ValueError(f"no timings recorded for {name!r}")
        return LatencyStats(
            mean=sum(times) / len(times),
            p50=self._percentile(times, 50),
            p95=self._percentile(times, 95),
            p99=self._percentile(times, 99),
            max=times[-1],
        )

    def report(self):
        """Print one summary line per recorded stage."""
        for name in self.timings:
            s = self.stats(name)
            print(f"{name}: mean={s.mean:.2f}ms p95={s.p95:.2f}ms p99={s.p99:.2f}ms")
# Usage: wrap each pipeline stage in its own measurement, then print the
# per-stage summary. `preprocess`, `model`, `postprocess` and `data` are
# expected to be supplied by the surrounding application.
profiler = LatencyProfiler()

with profiler.measure("preprocess"):
    preprocessed = preprocess(data)

with profiler.measure("inference"):
    output = model(preprocessed)

with profiler.measure("postprocess"):
    result = postprocess(output)

profiler.report()
```
## KV Cache for Transformers
```python
class KVCacheAttention(nn.Module):
    """Causal multi-head self-attention with a key/value cache.

    Feeding a sequence in one call, or token by token while passing the
    returned cache back in, produces the same outputs.

    Args:
        d_model: Model (embedding) width.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, past_kv=None):
        """Attend over the cached positions plus the new ones in ``x``.

        Args:
            x: Tensor of shape ``(B, T, d_model)`` holding the new positions.
            past_kv: ``(k, v)`` pair returned by a previous call, each of
                shape ``(B, num_heads, T_past, head_dim)``, or ``None``.

        Returns:
            ``(output, (k, v))`` where ``output`` has shape
            ``(B, T, d_model)`` and ``k``/``v`` have shape
            ``(B, num_heads, T_past + T, head_dim)``.
        """
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        # Move heads ahead of time, (B, T, H, hd) -> (B, H, T, hd), so the
        # matmuls below mix positions within a head rather than heads within
        # a position.
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(2))

        if past_kv is not None:
            past_k, past_v = past_kv
            # Append along the time axis (dim 2 in the head-major layout).
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        # Attention computation
        S = k.size(2)  # total number of key positions
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Causal mask: query i sits at absolute position (S - T + i) and may
        # only see keys at or before that position.
        visible = torch.ones(T, S, device=x.device).tril(diagonal=S - T).bool()
        attn = attn.masked_fill(~visible, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)  # (B, H, T, hd)

        # Back to (B, T, H * hd) before the output projection.
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.out(out), (k, v)
```
## Commands
- `/omgoptim:profile` - Profile latency
- `/omgdeploy:serve` - Optimized serving
- `/omgoptim:quantize` - Model quantization
## Best Practices
1. Measure before optimizing
2. Use dynamic batching for throughput
3. Compile models for production
4. Cache repeated computations
5. Profile end-to-end latency
This skill covers practical techniques to optimize ML model serving in production, focusing on batching, caching, compilation, async inference, and latency profiling. It distills actionable patterns and code-level strategies to reduce tail latency and improve throughput for JavaScript/Python-based stacks. The goal is measurable end-to-end inference improvements that integrate into existing serving infrastructure.
The skill inspects the inference pipeline and recommends optimizations at request handling, preprocessing, model inference, and postprocessing stages. It provides patterns for dynamic batching, model compilation (TorchScript, torch.compile, TensorRT, ONNX Runtime), caching (input/embedding/kv caches), asynchronous and parallel inference, and latency profiling. Each technique includes when to apply it and trade-offs for latency, throughput, and resource use.
Will batching always improve throughput?
Batching usually increases throughput, but it can also increase latency if the batching window is too long; use dynamic batching with a strict `max_wait_ms` bound to cap the added latency.
When should I use model compilation vs quantization?
Compile for lower overhead and runtime optimizations; quantize to reduce memory and compute cost. Combine both when accuracy remains acceptable.
How do I avoid stale cached outputs?
Use TTLs, deterministic input hashing, and cache invalidation policies; for embeddings consider similarity thresholds rather than exact matches.