ml-serving-optimization skill

This skill helps you optimize ML serving by applying batching, caching, model compilation, and latency-reduction techniques for production systems.

npx playbooks add skill doanchienthangdev/omgkit --skill ml-serving-optimization
---
name: ml-serving-optimization
description: ML serving optimization techniques including batching, caching, model compilation, and latency reduction for production ML systems.
---

# ML Serving Optimization

Optimizing ML model inference for production.

## Inference Pipeline

```
┌─────────────────────────────────────────────────────────────┐
│                 INFERENCE OPTIMIZATION STACK                 │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  REQUEST       PREPROCESSING    INFERENCE     POSTPROCESS   │
│  ────────      ─────────────    ─────────     ───────────   │
│  Batching      Vectorization    Compiled      Response      │
│  Queuing       Caching          Quantized     Caching       │
│  Load balance  Async I/O        Parallelized  Streaming     │
│                                                              │
│  LATENCY BREAKDOWN (typical):                               │
│  ├── Network: 1-5ms                                         │
│  ├── Preprocessing: 2-10ms                                  │
│  ├── Model inference: 5-100ms                               │
│  └── Postprocessing: 1-5ms                                  │
│                                                              │
└─────────────────────────────────────────────────────────────┘
```

## Dynamic Batching

```python
import asyncio
from collections import deque

import torch

class DynamicBatcher:
    def __init__(self, model, max_batch_size=32, max_wait_ms=10):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = deque()
        self.lock = asyncio.Lock()

    async def predict(self, input_data):
        future = asyncio.Future()
        async with self.lock:
            self.queue.append((input_data, future))

            if len(self.queue) >= self.max_batch_size:
                await self._process_batch()

        # wait_for both awaits and returns the result; it raises
        # TimeoutError (and cancels the future) if the batch never runs
        return await asyncio.wait_for(future, timeout=self.max_wait_ms / 1000 + 1)

    async def _process_batch(self):
        if not self.queue:
            return

        batch_items = []
        while self.queue and len(batch_items) < self.max_batch_size:
            batch_items.append(self.queue.popleft())

        inputs = torch.stack([item[0] for item in batch_items])
        with torch.no_grad():
            outputs = self.model(inputs)

        for i, (_, future) in enumerate(batch_items):
            future.set_result(outputs[i])

    async def batch_loop(self):
        while True:
            await asyncio.sleep(self.max_wait_ms / 1000)
            async with self.lock:
                if self.queue:
                    await self._process_batch()
```
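The same request-coalescing pattern can be exercised without a framework. A minimal, torch-free sketch that batches concurrent calls into a single invocation of a vectorized function (`MiniBatcher` and `batch_fn` are illustrative names, not part of the skill's API):

```python
import asyncio

class MiniBatcher:
    """Collects concurrent requests and runs them through a
    vectorized function in one call (illustrative sketch)."""
    def __init__(self, batch_fn, max_batch_size=8, max_wait_s=0.005):
        self.batch_fn = batch_fn
        self.max_batch_size = max_batch_size
        self.max_wait_s = max_wait_s
        self.pending = []

    async def predict(self, item):
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        self.pending.append((item, fut))
        if len(self.pending) >= self.max_batch_size:
            self._flush()
        else:
            # give other requests a chance to join the batch
            loop.call_later(self.max_wait_s, self._flush)
        return await fut

    def _flush(self):
        if not self.pending:
            return
        batch, self.pending = self.pending, []
        results = self.batch_fn([item for item, _ in batch])
        for (_, fut), res in zip(batch, results):
            if not fut.done():
                fut.set_result(res)

async def main():
    batcher = MiniBatcher(lambda xs: [x * 2 for x in xs], max_batch_size=4)
    return await asyncio.gather(*(batcher.predict(i) for i in range(10)))

# run with: asyncio.run(main())
```

Requests 0-7 fill two full batches and flush immediately; the trailing two are flushed by the `call_later` timer once `max_wait_s` elapses, which is the latency bound the batcher trades for throughput.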

## Model Compilation

### TorchScript
```python
import torch

# Tracing
traced_model = torch.jit.trace(model, example_input)
traced_model.save("model_traced.pt")

# Scripting (handles data-dependent control flow that tracing would freeze)
@torch.jit.script
def threshold_output(output, threshold: float = 0.5):
    if output.max() > threshold:
        return output
    return torch.zeros_like(output)

# Optimize for inference
traced_model = torch.jit.optimize_for_inference(traced_model)
```

### torch.compile (PyTorch 2.0+)
```python
import torch

# Default compilation
compiled_model = torch.compile(model)

# With options
compiled_model = torch.compile(
    model,
    mode="reduce-overhead",  # or "max-autotune"
    fullgraph=True,
    dynamic=False
)

# Explicit backend selection (inductor is the default)
compiled_model = torch.compile(
    model,
    mode="reduce-overhead",
    backend="inductor"
)
```

### TensorRT
```python
import torch
import torch_tensorrt

# Compile to TensorRT
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input(
        min_shape=[1, 3, 224, 224],
        opt_shape=[8, 3, 224, 224],
        max_shape=[32, 3, 224, 224],
        dtype=torch.float16
    )],
    enabled_precisions={torch.float16},
    workspace_size=1 << 30
)

# Save and load
torch.jit.save(trt_model, "model_trt.ts")
loaded = torch.jit.load("model_trt.ts")
```

### ONNX Runtime
```python
import torch
import onnxruntime as ort

# Export to ONNX
torch.onnx.export(
    model, example_input, "model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}}
)

# Create optimized session
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.intra_op_num_threads = 4

# GPU execution
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
session = ort.InferenceSession("model.onnx", session_options, providers=providers)

# Inference
outputs = session.run(None, {'input': input_data.numpy()})
```

## Caching Strategies

```python
import hashlib
import pickle

import redis
import torch

class InferenceCache:
    def __init__(self, redis_client, ttl=3600):
        self.redis = redis_client
        self.ttl = ttl

    def _hash_input(self, input_data):
        return hashlib.sha256(input_data.tobytes()).hexdigest()

    def get(self, input_data):
        key = self._hash_input(input_data)
        cached = self.redis.get(key)
        if cached:
            return pickle.loads(cached)
        return None

    def set(self, input_data, output):
        key = self._hash_input(input_data)
        self.redis.setex(key, self.ttl, pickle.dumps(output))

# Embedding cache for similar inputs
class EmbeddingCache:
    def __init__(self, threshold=0.95):
        self.embeddings = []
        self.outputs = []
        self.threshold = threshold

    def find_similar(self, embedding):
        if not self.embeddings:
            return None

        similarities = torch.cosine_similarity(
            embedding.unsqueeze(0),
            torch.stack(self.embeddings)
        )
        max_sim, idx = similarities.max(0)

        if max_sim > self.threshold:
            return self.outputs[idx]
        return None

    def add(self, embedding, output):
        self.embeddings.append(embedding)
        self.outputs.append(output)
```
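For single-process serving where a Redis round-trip is unnecessary, an in-memory LRU variant of the same idea can be built from the standard library alone (`LocalLRUCache` is an illustrative sketch, not part of the skill's API):

```python
import hashlib
import pickle
from collections import OrderedDict

class LocalLRUCache:
    """In-process LRU cache keyed by a hash of the serialized input."""
    def __init__(self, max_entries=1024):
        self.max_entries = max_entries
        self._store = OrderedDict()

    def _key(self, input_data):
        # pickle gives a deterministic byte view for plain Python inputs
        return hashlib.sha256(pickle.dumps(input_data)).hexdigest()

    def get(self, input_data):
        key = self._key(input_data)
        if key in self._store:
            self._store.move_to_end(key)  # mark as recently used
            return self._store[key]
        return None

    def set(self, input_data, output):
        key = self._key(input_data)
        self._store[key] = output
        self._store.move_to_end(key)
        if len(self._store) > self.max_entries:
            self._store.popitem(last=False)  # evict least recently used
```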

## Async Inference

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

import torch

class AsyncInferenceService:
    def __init__(self, model, num_workers=4):
        self.model = model
        self.executor = ThreadPoolExecutor(max_workers=num_workers)

    def _sync_predict(self, input_data):
        with torch.no_grad():
            return self.model(input_data)

    async def predict(self, input_data):
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            self.executor,
            self._sync_predict,
            input_data
        )

    async def predict_batch(self, inputs):
        tasks = [self.predict(inp) for inp in inputs]
        return await asyncio.gather(*tasks)

# CUDA streams for parallel inference
class StreamedInference:
    def __init__(self, model, num_streams=4):
        self.model = model
        self.streams = [torch.cuda.Stream() for _ in range(num_streams)]

    def predict_parallel(self, inputs):
        outputs = []
        for i, inp in enumerate(inputs):
            stream = self.streams[i % len(self.streams)]
            with torch.cuda.stream(stream):
                outputs.append(self.model(inp))

        torch.cuda.synchronize()
        return outputs
```
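The offload pattern itself is framework-agnostic. A torch-free sketch of the same `run_in_executor` approach, with `compute_fn` standing in for the blocking model call:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def compute_fn(x):
    # stand-in for a blocking model call
    return x * x

async def predict_all(inputs, num_workers=4):
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # each blocking call runs on a worker thread, keeping the event loop free
        futures = [loop.run_in_executor(pool, compute_fn, x) for x in inputs]
        return await asyncio.gather(*futures)
```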

## Latency Profiling

```python
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class LatencyStats:
    mean: float
    p50: float
    p95: float
    p99: float
    max: float

class LatencyProfiler:
    def __init__(self):
        self.timings: Dict[str, List[float]] = {}

    @contextmanager
    def measure(self, name: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            # record even if the measured block raises
            elapsed = (time.perf_counter() - start) * 1000  # ms
            self.timings.setdefault(name, []).append(elapsed)

    def stats(self, name: str) -> LatencyStats:
        times = sorted(self.timings[name])
        n = len(times)
        return LatencyStats(
            mean=sum(times) / n,
            p50=times[n // 2],
            p95=times[int(n * 0.95)],
            p99=times[int(n * 0.99)],
            max=times[-1]
        )

    def report(self):
        for name in self.timings:
            s = self.stats(name)
            print(f"{name}: mean={s.mean:.2f}ms p95={s.p95:.2f}ms p99={s.p99:.2f}ms")

# Usage
profiler = LatencyProfiler()

with profiler.measure("preprocess"):
    preprocessed = preprocess(data)

with profiler.measure("inference"):
    output = model(preprocessed)

with profiler.measure("postprocess"):
    result = postprocess(output)

profiler.report()
```

## KV Cache for Transformers

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class KVCacheAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, past_kv=None):
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        # (B, T, H, D) -> (B, H, T, D) so matmul batches over heads
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if past_kv is not None:
            # prepend cached keys/values along the sequence dimension
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        # Attention computation
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)  # (B, H, T, D)

        out = out.transpose(1, 2).reshape(B, T, C)
        return self.out(out), (k, v)
```
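In an autoregressive decoding loop, the cache returned at each step is threaded back into the next call so past keys and values are never recomputed. A framework-free sketch of that threading pattern (`toy_step` is a hypothetical stand-in for one decoder step, not the module above):

```python
def toy_step(token, past_kv):
    """Stand-in for one decoder step: returns the next token and
    an updated cache (here the cache is just the token history)."""
    cache = (past_kv or []) + [token]
    next_token = sum(cache) % 7  # dummy 'logits'
    return next_token, cache

def generate(first_token, steps):
    token, past_kv = first_token, None
    out = [token]
    for _ in range(steps):
        # only the newest token is passed; history lives in past_kv
        token, past_kv = toy_step(token, past_kv)
        out.append(token)
    return out
```

The point of the sketch is the data flow: each step receives only the newest token plus the cache, so per-step cost stays constant instead of growing with sequence length.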

## Commands
- `/omgoptim:profile` - Profile latency
- `/omgdeploy:serve` - Optimized serving
- `/omgoptim:quantize` - Model quantization

## Best Practices

1. Measure before optimizing
2. Use dynamic batching for throughput
3. Compile models for production
4. Cache repeated computations
5. Profile end-to-end latency

Overview

This skill covers practical techniques for optimizing ML model serving in production, focusing on batching, caching, compilation, async inference, and latency profiling. It distills actionable patterns and code-level strategies to reduce tail latency and improve throughput for Python-based serving stacks. The goal is measurable end-to-end inference improvements that integrate into existing serving infrastructure.

How this skill works

The skill inspects the inference pipeline and recommends optimizations at request handling, preprocessing, model inference, and postprocessing stages. It provides patterns for dynamic batching, model compilation (TorchScript, torch.compile, TensorRT, ONNX Runtime), caching (input/embedding/kv caches), asynchronous and parallel inference, and latency profiling. Each technique includes when to apply it and trade-offs for latency, throughput, and resource use.

When to use it

  • When tail latency exceeds SLOs and profiling shows model inference or queuing hotspots
  • When throughput needs to increase without linear GPU scaling
  • When many similar or repeated requests can reuse previous outputs (caching opportunity)
  • When CPU/GPU utilization is low due to small per-request work (use batching/compilation)
  • When transformers or autoregressive models require fast token sampling (KV cache)

Best practices

  • Measure end-to-end latency and per-stage timings before making changes
  • Start with low-risk techniques: response caching and async workers, then add batching
  • Use dynamic batching with a bounded max_wait to balance latency vs throughput
  • Compile or quantize models for production; validate numeric quality after changes
  • Profile p50/p95/p99 and iterate; optimize the dominant contributor first

Example use cases

  • REST or gRPC inference front-end that aggregates requests into dynamic batches to maximize GPU utilization
  • Embedding service that uses cosine-similarity caching to return nearest cached outputs for similar inputs
  • Transformer chat service that stores KV cache per session to avoid recomputing past keys/values
  • Edge or latency-sensitive inference where torch.compile or TensorRT reduces per-request overhead
  • Asynchronous API backed by a thread or process pool to keep request threads non-blocking

FAQ

Will batching always improve throughput?

Batching usually increases throughput but can increase latency if batch windows are too long; use dynamic batching with strict max_wait to limit added latency.

When should I use model compilation vs quantization?

Compile for lower overhead and runtime optimizations; quantize to reduce memory and compute cost. Combine both when accuracy remains acceptable.

How do I avoid stale cached outputs?

Use TTLs, deterministic input hashing, and cache invalidation policies; for embeddings consider similarity thresholds rather than exact matches.
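As a concrete illustration of the TTL approach, a minimal standard-library cache with an injectable clock (`TTLCache` is a hypothetical sketch; the clock parameter exists so expiry can be tested without sleeping):

```python
import time

class TTLCache:
    """Dict-backed cache whose entries expire after ttl_s seconds."""
    def __init__(self, ttl_s=3600.0, clock=time.monotonic):
        self.ttl_s = ttl_s
        self.clock = clock  # injectable for testing
        self._store = {}

    def set(self, key, value):
        self._store[key] = (value, self.clock())

    def get(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        value, stored_at = entry
        if self.clock() - stored_at > self.ttl_s:
            del self._store[key]  # expired: invalidate lazily
            return None
        return value
```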