
This skill helps you manage end-to-end MLflow workflows, covering experiment tracking, the model registry, and deployment, with a focus on reproducibility and governance.

npx playbooks add skill amnadtaowsoam/cerebraskills --skill mlflow-patterns

---
name: MLflow Patterns
description: ML experiment tracking, model registry, and deployment with MLflow for reproducible machine learning workflows.
---

# MLflow Patterns

## Overview

MLflow is an open-source platform for managing the complete ML lifecycle, including experiment tracking, model packaging, model registry, and deployment. It enables data science teams to collaborate and deploy models reproducibly.

## Why This Matters

- **Reproducibility**: Track experiments and reproduce results
- **Collaboration**: Share experiments and models across teams
- **Deployment**: Package and deploy models consistently
- **Governance**: Model versioning and approval workflow

---

## Core Concepts

### 1. Experiment Tracking

```python
import mlflow
import mlflow.data
import mlflow.xgboost
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from xgboost import XGBClassifier

# Assumes X_train, X_test (pandas DataFrames) and y_train, y_test already exist,
# and that feature_importance.png / confusion_matrix.png were written earlier.

# Point the client at the tracking server and select (or create) the experiment
mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("customer-churn-prediction")

# Optional auto-logging: XGBoost models are instrumented by mlflow.xgboost.autolog(),
# not mlflow.sklearn.autolog(). This example logs everything explicitly instead.

params = {
    "learning_rate": 0.1,
    "max_depth": 6,
    "n_estimators": 100,
    "subsample": 0.8,
}

with mlflow.start_run(run_name="xgboost-v1") as run:
    # Log parameters
    mlflow.log_params(params)

    # Train model with exactly the parameters that were logged
    model = XGBClassifier(**params)
    model.fit(X_train, y_train)

    # Log evaluation metrics on the held-out set
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test)[:, 1]
    mlflow.log_metrics({
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred),
        "recall": recall_score(y_test, y_pred),
        "f1": f1_score(y_test, y_pred),
        "auc_roc": roc_auc_score(y_test, y_proba),
    })

    # Log artifacts
    mlflow.log_artifact("feature_importance.png")
    mlflow.log_artifact("confusion_matrix.png")

    # Log the model and register it as a new registry version in one step
    mlflow.xgboost.log_model(
        model,
        artifact_path="model",
        registered_model_name="churn-prediction-model",
    )

    # Log dataset lineage for the training data
    mlflow.log_input(
        mlflow.data.from_pandas(X_train, source="s3://data/train.parquet"),
        context="training",
    )

    print(f"Run ID: {run.info.run_id}")
```

### 2. Custom Model Wrapper

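```python
import json

import mlflow
import mlflow.pyfunc
import pandas as pd
from mlflow.models import infer_signature


class ChurnModelWrapper(mlflow.pyfunc.PythonModel):
    """Custom model wrapper that bundles preprocessing with prediction"""

    def load_context(self, context):
        """Load model and artifacts; context.artifacts maps each name to a local file path"""
        import joblib

        self.model = joblib.load(context.artifacts["model"])
        self.preprocessor = joblib.load(context.artifacts["preprocessor"])
        # features.json holds a JSON list of column names, so the file must be parsed
        with open(context.artifacts["feature_names"]) as f:
            self.feature_names = json.load(f)

    def predict(self, context, model_input, params=None):
        """Validate columns, preprocess, and return probability plus a 0.5-threshold label"""
        # Report only the columns that are actually missing
        missing = [col for col in self.feature_names if col not in model_input.columns]
        if missing:
            raise ValueError(f"Missing required features: {missing}")

        # Select columns in training order before transforming
        processed = self.preprocessor.transform(model_input[self.feature_names])

        # Probability of the positive (churn) class
        predictions = self.model.predict_proba(processed)[:, 1]

        return pd.DataFrame({
            "churn_probability": predictions,
            "churn_prediction": (predictions > 0.5).astype(int),
        })


# Log custom model.
# Assumes model.joblib, preprocessor.joblib and features.json were saved locally,
# and X_test is a DataFrame with the raw feature columns.
with mlflow.start_run():
    artifacts = {
        "model": "model.joblib",
        "preprocessor": "preprocessor.joblib",
        "feature_names": "features.json",
    }

    # Output example with the same columns and dtypes that predict() returns
    example_output = pd.DataFrame({"churn_probability": [0.0], "churn_prediction": [0]})

    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=ChurnModelWrapper(),
        artifacts=artifacts,
        # Pinned runtime dependencies; the Python version is recorded from the logging environment
        pip_requirements=[
            "scikit-learn==1.3.0",
            "xgboost==2.0.0",
            "pandas==2.0.0",
            "joblib",
        ],
        signature=infer_signature(X_test, example_output),
        input_example=X_test.head(5),
    )
```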

### 3. Model Registry

```python
import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register the model logged by a finished run (run_id comes from the training run).
# Skip this step if log_model() was called with registered_model_name, because
# that call already created a registry version.
model_uri = f"runs:/{run_id}/model"
model_version = mlflow.register_model(model_uri, "churn-prediction-model")

# Add description and tags
client.update_model_version(
    name="churn-prediction-model",
    version=model_version.version,
    description="XGBoost model trained on Q4 2024 data",
)

client.set_model_version_tag(
    name="churn-prediction-model",
    version=model_version.version,
    key="validation_status",
    value="pending",
)

# Note: stage transitions still work but are deprecated in recent MLflow releases
# in favor of aliases (client.set_registered_model_alias, "models:/<name>@<alias>").

# Transition to staging (after validation)
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Staging",
    archive_existing_versions=False,
)

# Promote to production (after approval)
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Production",
    archive_existing_versions=True,  # Archive the previous Production version
)

# Load production model; new_data must be a DataFrame with the model's input columns
model = mlflow.pyfunc.load_model("models:/churn-prediction-model/Production")
predictions = model.predict(new_data)
```
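The registry examples address models with `models:/` URIs, by version or by stage, and runs with `runs:/` URIs. Recent MLflow 2.x releases also accept `models:/<name>@<alias>` for aliases set with `client.set_registered_model_alias(name, alias, version)`.

Below is a small stdlib sketch that builds each URI form in one place. `model_uri` and `run_uri` are hypothetical helpers, not MLflow APIs, and the alias `champion` is an illustrative name.

```python
from typing import Optional


def model_uri(
    name: str,
    *,
    version: Optional[int] = None,
    stage: Optional[str] = None,
    alias: Optional[str] = None,
) -> str:
    """Build a models:/ URI from exactly one of version, stage, or alias."""
    chosen = [value for value in (version, stage, alias) if value is not None]
    if len(chosen) != 1:
        raise ValueError("pass exactly one of version, stage, or alias")
    if version is not None:
        return f"models:/{name}/{version}"  # pin an exact registry version
    if stage is not None:
        return f"models:/{name}/{stage}"  # latest version in a stage, e.g. "Production"
    return f"models:/{name}@{alias}"  # whichever version currently holds the alias


def run_uri(run_id: str, artifact_path: str = "model") -> str:
    """URI of a model logged inside a specific run."""
    return f"runs:/{run_id}/{artifact_path}"
```

For example, `mlflow.pyfunc.load_model(model_uri("churn-prediction-model", alias="champion"))` would load whichever version holds that alias, assuming the alias was set beforehand.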

### 4. Model Validation Pipeline

```python
# validation/validate_model.py
import mlflow
import pandas as pd
from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient
from sklearn.metrics import accuracy_score, roc_auc_score


def validate_model(model_name: str, version: str) -> bool:
    """Validate a registered version before promotion.

    Assumes the version is the ChurnModelWrapper pyfunc, whose predict()
    returns churn_probability and churn_prediction columns.
    """
    client = MlflowClient()
    model_uri = f"models:/{model_name}/{version}"

    # Load model
    model = mlflow.pyfunc.load_model(model_uri)

    # Load validation dataset (reading from S3 requires s3fs)
    val_data = pd.read_parquet("s3://data/validation.parquet")
    X_val, y_val = val_data.drop("target", axis=1), val_data["target"]

    # Run predictions
    predictions = model.predict(X_val)

    # Calculate metrics
    metrics = {
        "val_accuracy": accuracy_score(y_val, predictions["churn_prediction"]),
        "val_auc": roc_auc_score(y_val, predictions["churn_probability"]),
    }

    # Get production model metrics; if no Production version exists yet,
    # fall back to zeros so the comparison rules pass trivially
    try:
        prod_model = mlflow.pyfunc.load_model(f"models:/{model_name}/Production")
        prod_predictions = prod_model.predict(X_val)
        prod_metrics = {
            "prod_accuracy": accuracy_score(y_val, prod_predictions["churn_prediction"]),
            "prod_auc": roc_auc_score(y_val, prod_predictions["churn_probability"]),
        }
    except MlflowException:
        prod_metrics = {"prod_accuracy": 0.0, "prod_auc": 0.0}

    # Validation rules
    validations = [
        ("accuracy_threshold", metrics["val_accuracy"] >= 0.85),
        ("auc_threshold", metrics["val_auc"] >= 0.80),
        ("accuracy_improvement", metrics["val_accuracy"] >= prod_metrics["prod_accuracy"]),
        ("auc_improvement", metrics["val_auc"] >= prod_metrics["prod_auc"] - 0.01),  # Allow a 0.01 absolute AUC drop
    ]

    # Log validation results
    with mlflow.start_run(run_name=f"validation-{model_name}-v{version}"):
        mlflow.log_metrics(metrics)
        mlflow.log_metrics(prod_metrics)

        for name, passed in validations:
            mlflow.log_metric(f"validation_{name}", int(passed))

    # Record the outcome on the model version
    all_passed = all(passed for _, passed in validations)
    client.set_model_version_tag(
        name=model_name,
        version=version,
        key="validation_status",
        value="passed" if all_passed else "failed",
    )

    return all_passed
```
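The rules in `validate_model` are plain comparisons. One option is to move them into a pure function so CI can unit-test the gate without a tracking server or data.

Below is a minimal sketch. `GateConfig`, `evaluate_gate` and `should_promote` are hypothetical names, and the thresholds are copied from the example above.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class GateConfig:
    """Thresholds copied from validate_model (hypothetical config object)."""

    min_accuracy: float = 0.85
    min_auc: float = 0.80
    auc_tolerance: float = 0.01  # absolute AUC drop allowed relative to production


def evaluate_gate(
    candidate: Dict[str, float],
    production: Optional[Dict[str, float]] = None,
    cfg: GateConfig = GateConfig(),
) -> Dict[str, bool]:
    """Return {rule_name: passed}; comparison rules pass trivially with no production model."""
    prod_accuracy = production["accuracy"] if production else 0.0
    prod_auc = production["auc"] if production else 0.0
    return {
        "accuracy_threshold": candidate["accuracy"] >= cfg.min_accuracy,
        "auc_threshold": candidate["auc"] >= cfg.min_auc,
        "accuracy_improvement": candidate["accuracy"] >= prod_accuracy,
        "auc_improvement": candidate["auc"] >= prod_auc - cfg.auc_tolerance,
    }


def should_promote(results: Dict[str, bool]) -> bool:
    """Promote only when every rule passed."""
    return all(results.values())
```

`validate_model` could then build its `validations` list from `evaluate_gate(...).items()`, passing `None` for production when no Production version exists.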

### 5. Model Serving

```python
# serve/model_server.py
from contextlib import asynccontextmanager
from typing import Any, Dict

import mlflow
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

MODEL_NAME = "churn-prediction-model"
MODEL_STAGE = "Production"
model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once at startup instead of per request
    global model
    model = mlflow.pyfunc.load_model(f"models:/{MODEL_NAME}/{MODEL_STAGE}")
    yield


app = FastAPI(lifespan=lifespan)


class PredictionRequest(BaseModel):
    features: Dict[str, Any]


class PredictionResponse(BaseModel):
    churn_probability: float
    churn_prediction: int
    run_id: str  # run ID of the loaded model, not a registry version number


# Plain `def` so FastAPI runs the blocking predict() call in a worker thread
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        input_df = pd.DataFrame([request.features])
        predictions = model.predict(input_df)
    except Exception as e:
        # Treat prediction failures (e.g. missing features) as bad input
        raise HTTPException(status_code=400, detail=str(e))

    return PredictionResponse(
        churn_probability=float(predictions["churn_probability"].iloc[0]),
        churn_prediction=int(predictions["churn_prediction"].iloc[0]),
        run_id=model.metadata.run_id,
    )


@app.get("/health")
def health():
    return {"status": "healthy", "model_loaded": model is not None}


# Or use MLflow's built-in serving
# mlflow models serve -m "models:/churn-prediction-model/Production" -p 5001
```
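`mlflow models serve` exposes a POST `/invocations` endpoint. In MLflow 2.x it accepts JSON bodies such as `{"dataframe_split": {"columns": [...], "data": [[...]]}}`, and the reply is typically wrapped under a `"predictions"` key.

Below is a stdlib-only client sketch, assuming the server runs on port 5001 as in the command above. `build_invocations_payload` and `score` are hypothetical helpers, and the feature names `tenure` and `monthly_charges` are illustrative only.

```python
import json
import urllib.request
from typing import Any, Dict, List


def build_invocations_payload(records: List[Dict[str, Any]]) -> bytes:
    """Encode rows using the "dataframe_split" JSON layout of MLflow 2.x scoring servers."""
    # Union of keys across all rows; values absent from a row become JSON null
    columns = sorted({key for row in records for key in row})
    data = [[row.get(col) for col in columns] for row in records]
    body = {"dataframe_split": {"columns": columns, "data": data}}
    return json.dumps(body).encode("utf-8")


def score(records: List[Dict[str, Any]], url: str = "http://localhost:5001/invocations") -> Any:
    """POST rows to a running `mlflow models serve` endpoint and return the parsed JSON reply."""
    request = urllib.request.Request(
        url,
        data=build_invocations_payload(records),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())
```

Column order is not significant for the wrapper above, because it selects its features by name.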

## Quick Start

1. **Install MLflow:**
   ```bash
   pip install mlflow
   ```

2. **Start tracking server:**
   ```bash
   mlflow server --backend-store-uri sqlite:///mlflow.db \
                 --default-artifact-root s3://mlflow-artifacts \
                 --host 0.0.0.0
   ```

3. **Set tracking URI in code:**
   ```python
   import mlflow
   mlflow.set_tracking_uri("http://localhost:5000")
   ```

4. **Run experiment:**
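   ```python
   import mlflow
   from sklearn.datasets import load_iris
   from sklearn.linear_model import LogisticRegression

   X, y = load_iris(return_X_y=True)
   with mlflow.start_run():
       model = LogisticRegression(max_iter=200).fit(X, y)
       mlflow.log_param("max_iter", 200)
       mlflow.log_metric("train_accuracy", model.score(X, y))
       mlflow.sklearn.log_model(model, "model")
   ```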

5. **View in UI:** Open http://localhost:5000

## Production Checklist

- [ ] Tracking server with persistent backend
- [ ] Artifact storage (S3/GCS/Azure Blob)
- [ ] Authentication enabled
- [ ] Model signature defined
- [ ] Input examples logged
- [ ] Conda/pip environment specified
- [ ] Validation pipeline configured
- [ ] Model approval workflow
- [ ] Monitoring for model drift
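The checklist names drift monitoring without prescribing a method. One commonly used metric is the Population Stability Index (PSI). This is a swapped-in technique, not something this skill specifies. PSI compares a feature's binned distribution at training time with a recent production window.

Below is a minimal stdlib sketch, assuming both inputs are per-bin proportions over the same bins. `population_stability_index` is a hypothetical helper.

```python
import math
from typing import Sequence


def population_stability_index(
    expected: Sequence[float], actual: Sequence[float], eps: float = 1e-6
) -> float:
    """PSI = sum((a - e) * ln(a / e)) over bins, for baseline proportions e and live proportions a."""
    if len(expected) != len(actual):
        raise ValueError("expected and actual must use the same bins")
    psi = 0.0
    for e, a in zip(expected, actual):
        # Clamp empty bins so the ratio and log() stay finite
        e, a = max(e, eps), max(a, eps)
        psi += (a - e) * math.log(a / e)
    return psi
```

A frequently cited rule of thumb treats PSI above roughly 0.25 as a significant shift, but thresholds should be tuned per feature. Per-feature values could be recorded with `mlflow.log_metric` in a scheduled monitoring run.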

## Anti-patterns

1. **No Experiment Naming**: Use meaningful experiment/run names
2. **Skipping Signatures**: Always define model signatures
3. **Manual Promotion**: Use validation pipeline for stage transitions
4. **Missing Environment**: Always specify dependencies

## Integration Points

- **Storage**: S3, GCS, Azure Blob, HDFS
- **Databases**: PostgreSQL, MySQL for backend store
- **Orchestration**: Airflow, Prefect, Dagster
- **Serving**: SageMaker, Kubernetes, Azure ML

## Further Reading

- [MLflow Documentation](https://mlflow.org/docs/latest/index.html)
- [MLflow Model Registry](https://mlflow.org/docs/latest/model-registry.html)
- [MLflow Recipes](https://mlflow.org/docs/latest/recipes.html)

Overview

This skill shows patterns for using MLflow to track experiments, register models, validate versions, and serve production-ready models. It focuses on reproducible ML workflows: logging runs, packaging custom models, enforcing validation rules, and promoting models through staging to production. The content is practical and targeted at teams building governed model delivery pipelines.

How this skill works

The skill walks through common MLflow workflows with concrete code patterns: experiment tracking with parameters, metrics, artifacts, and datasets; custom pyfunc model wrappers for preprocessing and prediction; model registry operations for versioning and stage transitions; automated validation pipelines that compare a candidate against the current production model; and serving examples using FastAPI or MLflow’s built-in server. Each pattern logs metadata and artifacts to enable reproducibility and traceability.

When to use it

  • When you need reproducible experiment tracking and artifact management
  • When you want a model registry workflow with staged promotion and version control
  • When deploying models that require consistent runtime environments and signatures
  • When building automated validation gates before promoting models to production
  • When integrating ML lifecycle with storage, orchestration, or serving platforms

Best practices

  • Set a central tracking URI and meaningful experiment/run names for discoverability
  • Always log model signature, input examples, and environment (conda/pip) for reproducible serving
  • Wrap preprocessing and model logic into a pyfunc PythonModel to ensure consistent inference behavior
  • Automate validation and promotion with explicit, logged metrics and allow small tolerances when comparing to production
  • Configure persistent backend and artifact storage, and enable authentication and monitoring for production deployments

Example use cases

  • Customer churn model: log experiments, register versions, validate against production, and serve via FastAPI or `mlflow models serve`
  • Champion/challenger promotion: run the automated validation pipeline against the current production model and promote the candidate to Staging/Production, recording the outcome in registry tags
  • Preprocessing encapsulation: package complex preprocessing and the model in a pyfunc wrapper so inference applies exactly the preprocessing used in training
  • End-to-end CI: integrate the validation script into CI/CD to gate registry transitions and update model tags automatically

FAQ

How do I handle missing features at inference?

Use a pyfunc wrapper that validates input columns and either raises a clear error or applies default imputation before prediction.
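Below is a minimal sketch of that check-then-impute step. It uses plain Python over dict rows rather than a DataFrame. `prepare_rows`, the feature names and the default values are hypothetical.

```python
from typing import Dict, List, Mapping, Optional, Sequence


def prepare_rows(
    rows: Sequence[Mapping[str, object]],
    feature_names: Sequence[str],
    defaults: Optional[Mapping[str, object]] = None,
) -> List[Dict[str, object]]:
    """Restrict rows to feature_names, filling absent keys from defaults.

    Raises ValueError listing the missing columns when a feature is absent
    and no default is configured for it.
    """
    defaults = defaults or {}
    prepared = []
    for i, row in enumerate(rows):
        missing = [f for f in feature_names if f not in row and f not in defaults]
        if missing:
            raise ValueError(f"row {i}: missing required features {missing}")
        # Row values win; defaults only fill keys the row does not provide
        prepared.append({f: row.get(f, defaults.get(f)) for f in feature_names})
    return prepared
```

Inside a pyfunc wrapper, the same logic would typically run column-wise on `model_input` before calling the preprocessor.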

When should I archive previous production versions?

Archive older production versions when promoting a new production model. This keeps exactly one version in the Production stage while older versions stay in the registry for rollback.
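The toy in-memory model below makes the archive-and-rollback behaviour concrete for one registered model's stages. `ToyRegistry` is hypothetical and does not talk to MLflow. It only mimics the documented `archive_existing_versions` semantics: other versions in the target stage move to "Archived", and the option is valid only for Staging or Production.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyRegistry:
    """In-memory stand-in for one registered model's version stages; not the MLflow API."""

    stages: Dict[int, str] = field(default_factory=dict)  # version -> stage

    def register(self) -> int:
        version = len(self.stages) + 1
        self.stages[version] = "None"  # new versions start outside any stage
        return version

    def transition(self, version: int, stage: str, archive_existing: bool = False) -> None:
        if archive_existing:
            if stage not in ("Staging", "Production"):
                raise ValueError("archive_existing only applies to Staging or Production")
            # Like archive_existing_versions=True: archive every other version in the target stage
            for other, current in list(self.stages.items()):
                if other != version and current == stage:
                    self.stages[other] = "Archived"
        self.stages[version] = stage

    def in_stage(self, stage: str) -> List[int]:
        return [v for v, s in self.stages.items() if s == stage]
```

Rolling back is then just another promotion of the archived version with archiving enabled, which archives the version that replaced it.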