
domain-ml skill


This skill helps you build efficient Rust ML apps by applying domain constraints around memory use, GPU utilization, and model portability.

npx playbooks add skill zhanghandong/rust-skills --skill domain-ml


SKILL.md
---
name: domain-ml
description: "Use when building ML/AI apps in Rust. Keywords: machine learning, ML, AI, tensor, model, inference, neural network, deep learning, training, prediction, ndarray, tch-rs, burn, candle, 机器学习, 人工智能, 模型推理"
user-invocable: false
---

# Machine Learning Domain

> **Layer 3: Domain Constraints**

## Domain Constraints → Design Implications

| Domain Rule | Design Constraint | Rust Implication |
|-------------|-------------------|------------------|
| Large data | Efficient memory | Zero-copy, streaming |
| GPU acceleration | CUDA/Metal support | candle, tch-rs |
| Model portability | Standard formats | ONNX |
| Batch processing | Throughput over latency | Batched inference |
| Numerical precision | Float handling | ndarray, careful f32/f64 |
| Reproducibility | Deterministic | Seeded random, versioning |

---

## Critical Constraints

### Memory Efficiency

```
RULE: Avoid copying large tensors
WHY: Memory bandwidth is bottleneck
RUST: References, views, in-place ops
```

### GPU Utilization

```
RULE: Batch operations for GPU efficiency
WHY: GPU overhead per kernel launch
RUST: Batch sizes, async data loading
```

### Model Portability

```
RULE: Use standard model formats
WHY: Train in Python, deploy in Rust
RUST: ONNX via tract or candle
```

---

## Trace Down ↓

From constraints to design (Layer 2):

```
"Need efficient data pipelines"
    ↓ m10-performance: Streaming, batching
    ↓ polars: Lazy evaluation

"Need GPU inference"
    ↓ m07-concurrency: Async data loading
    ↓ candle/tch-rs: CUDA backend

"Need model loading"
    ↓ m12-lifecycle: Lazy init, caching
    ↓ tract: ONNX runtime
```

---

## Use Case → Framework

| Use Case | Recommended | Why |
|----------|-------------|-----|
| Inference only | tract (ONNX) | Lightweight, portable |
| Training + inference | candle, burn | Pure Rust, GPU |
| PyTorch models | tch-rs | Direct bindings |
| Data pipelines | polars | Fast, lazy eval |

## Key Crates

| Purpose | Crate |
|---------|-------|
| Tensors | ndarray |
| ONNX inference | tract |
| ML framework | candle, burn |
| PyTorch bindings | tch-rs |
| Data processing | polars |
| Embeddings | fastembed |

## Design Patterns

| Pattern | Purpose | Implementation |
|---------|---------|----------------|
| Model loading | Once, reuse | `OnceLock<Model>` |
| Batching | Throughput | Collect then process |
| Streaming | Large data | Iterator-based |
| GPU async | Parallelism | Data loading parallel to compute |

## Code Pattern: Inference Server

```rust
use std::sync::OnceLock;
use tract_onnx::prelude::*;

// Alias the verbose plan type so signatures stay readable.
type Model = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;

static MODEL: OnceLock<Model> = OnceLock::new();

fn get_model() -> &'static Model {
    MODEL.get_or_init(|| {
        tract_onnx::onnx()
            .model_for_path("model.onnx")
            .unwrap()
            .into_optimized()
            .unwrap()
            .into_runnable()
            .unwrap()
    })
}

async fn predict(input: Vec<f32>) -> anyhow::Result<Vec<f32>> {
    let model = get_model();
    // Reshape the flat input into a single-row batch: (1, n).
    let n = input.len();
    let tensor: Tensor = tract_ndarray::Array2::from_shape_vec((1, n), input)?.into();
    let result = model.run(tvec!(tensor.into()))?;
    Ok(result[0].to_array_view::<f32>()?.iter().copied().collect())
}
```

## Code Pattern: Batched Inference

```rust
// Sketch: `Model`, `stack_inputs`, and `unstack_outputs` are placeholders
// for your framework's model handle and tensor (un)stacking helpers.
async fn batch_predict(model: &Model, inputs: Vec<Vec<f32>>, batch_size: usize) -> Vec<Vec<f32>> {
    let mut results = Vec::with_capacity(inputs.len());

    for batch in inputs.chunks(batch_size) {
        // Stack the rows into one (batch, features) tensor.
        let batch_tensor = stack_inputs(batch);

        // One inference call per batch amortizes per-kernel launch overhead.
        let batch_output = model.run(batch_tensor).await;

        // Split the batched output back into per-input results.
        results.extend(unstack_outputs(batch_output));
    }

    results
}
```

---

## Common Mistakes

| Mistake | Domain Violation | Fix |
|---------|-----------------|-----|
| Clone tensors | Memory waste | Use views |
| Single inference | GPU underutilized | Batch processing |
| Load model per request | Slow | Singleton pattern |
| Sync data loading | GPU idle | Async pipeline |

---

## Trace to Layer 1

| Constraint | Layer 2 Pattern | Layer 1 Implementation |
|------------|-----------------|------------------------|
| Memory efficiency | Zero-copy | ndarray views |
| Model singleton | Lazy init | `OnceLock<Model>` |
| Batch processing | Chunked iteration | chunks() + parallel |
| GPU async | Concurrent loading | tokio::spawn + GPU |

---

## Related Skills

| When | See |
|------|-----|
| Performance | m10-performance |
| Lazy initialization | m12-lifecycle |
| Async patterns | m07-concurrency |
| Memory efficiency | m01-ownership |

Overview

This skill provides practical guidance for building ML/AI applications in Rust with a focus on memory, GPU utilization, model portability, and high-throughput inference. It summarizes recommended crates, design patterns, and concrete code patterns for singleton model loading, batched inference, and streaming pipelines. Use it to make Rust ML systems efficient, reproducible, and production-ready.

How this skill works

The skill inspects domain constraints (large data, GPU, portability, precision, reproducibility) and maps them to Rust design implications and patterns. It recommends specific crates (tract, candle, tch-rs, burn, ndarray, polars) and shows how to implement singleton model loading, batch processing, and async data pipelines. It also enumerates common mistakes and fixes to avoid memory waste and GPU underutilization.

When to use it

  • Building an inference service that loads ONNX models in Rust
  • Training or running models on GPU with pure-Rust frameworks (candle, burn)
  • Creating high-throughput batch inference pipelines
  • Processing large datasets with zero-copy or streaming approaches
  • Porting Python-trained models to Rust for production deployment

Best practices

  • Avoid copying large tensors: use references, views, and in-place operations
  • Load models once and reuse: use OnceLock or lazy initialization to prevent per-request loads
  • Batch GPU work: collect inputs into batches to maximize GPU throughput
  • Use standard model formats (ONNX) for portability between ecosystems
  • Stream large data and use iterator-based pipelines or polars for lazy evaluation
  • Seed RNGs and pin dependency versions for deterministic behavior

Example use cases

  • Lightweight ONNX inference server using tract and OnceLock for model reuse
  • Batch prediction pipeline: collect records, stack into batch tensors, run model, then unstack outputs
  • GPU training/inference using candle or burn with async data loaders to hide IO latency
  • Deploying PyTorch-trained models via tch-rs for direct interoperability with existing PyTorch artifacts
  • Large-scale feature processing with polars and ndarray views to avoid copies

FAQ

How do I avoid memory copies when preparing tensors?

Use ndarray views and slice references, build tensors in-place, and prefer stack/concatenate operations that accept views. Avoid cloning large buffers.

Which crate should I use for ONNX models?

Use tract for lightweight, portable ONNX inference in Rust. For full Rust training frameworks, prefer candle or burn; use tch-rs when you need direct PyTorch compatibility.