This skill helps you manage end-to-end MLflow workflows, covering experiment tracking, the model registry, and deployment, to ensure reproducibility and governance.
To add this skill to your agents, run `npx playbooks add skill amnadtaowsoam/cerebraskills --skill mlflow-patterns`.
---
name: MLflow Patterns
description: ML experiment tracking, model registry, and deployment with MLflow for reproducible machine learning workflows.
---
# MLflow Patterns
## Overview
MLflow is an open-source platform for managing the complete ML lifecycle, including experiment tracking, model packaging, model registry, and deployment. It enables data science teams to collaborate and deploy models reproducibly.
## Why This Matters
- **Reproducibility**: Track experiments and reproduce results
- **Collaboration**: Share experiments and models across teams
- **Deployment**: Package and deploy models consistently
- **Governance**: Model versioning and approval workflow
---
## Core Concepts
### 1. Experiment Tracking
```python
import mlflow
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from xgboost import XGBClassifier

# X_train, X_test, y_train, y_test are assumed to be prepared beforehand

# Set tracking URI
mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("customer-churn-prediction")

# Start run with auto-logging
mlflow.sklearn.autolog()

with mlflow.start_run(run_name="xgboost-v1") as run:
    # Log parameters (defined once so the logged values match the model)
    params = {
        "learning_rate": 0.1,
        "max_depth": 6,
        "n_estimators": 100,
        "subsample": 0.8,
    }
    mlflow.log_params(params)

    # Train model
    model = XGBClassifier(**params)
    model.fit(X_train, y_train)

    # Log metrics
    y_pred = model.predict(X_test)
    mlflow.log_metrics({
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred),
        "recall": recall_score(y_test, y_pred),
        "f1": f1_score(y_test, y_pred),
        "auc_roc": roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]),
    })

    # Log artifacts (the image files are assumed to exist already)
    mlflow.log_artifact("feature_importance.png")
    mlflow.log_artifact("confusion_matrix.png")

    # Log model
    mlflow.sklearn.log_model(
        model,
        artifact_path="model",
        registered_model_name="churn-prediction-model",
    )

    # Log dataset info
    mlflow.log_input(
        mlflow.data.from_pandas(X_train, source="s3://data/train.parquet"),
        context="training",
    )

print(f"Run ID: {run.info.run_id}")
```
### 2. Custom Model Wrapper
```python
import json

import joblib
import mlflow
import mlflow.pyfunc
import pandas as pd


class ChurnModelWrapper(mlflow.pyfunc.PythonModel):
    """Custom model wrapper with preprocessing"""

    def load_context(self, context):
        """Load model and artifacts"""
        self.model = joblib.load(context.artifacts["model"])
        self.preprocessor = joblib.load(context.artifacts["preprocessor"])
        # context.artifacts maps names to local file paths, so the JSON file
        # (assumed to contain a list of column names) has to be read
        with open(context.artifacts["feature_names"]) as f:
            self.feature_names = json.load(f)

    def predict(self, context, model_input):
        """Predict with preprocessing"""
        # Validate input
        missing = [c for c in self.feature_names if c not in model_input.columns]
        if missing:
            raise ValueError(f"Missing required features: {missing}")

        # Preprocess
        processed = self.preprocessor.transform(model_input[self.feature_names])

        # Predict with probability
        predictions = self.model.predict_proba(processed)[:, 1]
        return pd.DataFrame({
            "churn_probability": predictions,
            "churn_prediction": (predictions > 0.5).astype(int),
        })


# Log custom model
with mlflow.start_run():
    artifacts = {
        "model": "model.joblib",
        "preprocessor": "preprocessor.joblib",
        "feature_names": "features.json",
    }
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=ChurnModelWrapper(),
        artifacts=artifacts,
        conda_env={
            "dependencies": [
                "python=3.10",
                "scikit-learn=1.3.0",
                "xgboost=2.0.0",
                "pandas=2.0.0",
            ]
        },
        # `predictions` is assumed to hold the wrapper's output for X_test
        signature=mlflow.models.infer_signature(X_test, predictions),
        input_example=X_test.head(5),
    )
```
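The wrapper above rejects any input that lacks a required column. A minimal sketch of the alternative, filling declared defaults and reporting only the features that are still missing, might look like the following. It works on a plain dict (the shape of `request.features` in the serving example) so it runs without MLflow. `prepare_features`, the `defaults` mapping, and the feature names in the usage note are hypothetical, not part of MLflow.

```python
from typing import Any, Mapping, Optional, Sequence


def prepare_features(
    record: Mapping[str, Any],
    required: Sequence[str],
    defaults: Optional[Mapping[str, Any]] = None,
) -> dict:
    """Return the required features in order, filling gaps from `defaults`.

    Raises ValueError naming only the features that are missing and have
    no default. Hypothetical helper, not an MLflow API.
    """
    defaults = defaults or {}
    missing = [
        name for name in required
        if name not in record and name not in defaults
    ]
    if missing:
        raise ValueError(f"Missing required features: {missing}")
    # Keep only the required features, in the order the model expects
    return {name: record.get(name, defaults.get(name)) for name in required}
```

For example, `prepare_features({"tenure": 12}, ["tenure", "monthly_charges"], {"monthly_charges": 0.0})` fills the missing charge with its default, while omitting the `defaults` argument raises an error that names `monthly_charges` only. Whether silent imputation is acceptable depends on the model, so the strict behavior is probably the safer default.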
### 3. Model Registry
```python
import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register model from run (run_id is assumed to come from a finished run,
# e.g. run.info.run_id)
model_uri = f"runs:/{run_id}/model"
model_version = mlflow.register_model(model_uri, "churn-prediction-model")

# Add description and tags
client.update_model_version(
    name="churn-prediction-model",
    version=model_version.version,
    description="XGBoost model trained on Q4 2024 data",
)
client.set_model_version_tag(
    name="churn-prediction-model",
    version=model_version.version,
    key="validation_status",
    value="pending",
)

# Note: model stages are deprecated since MLflow 2.9 in favor of
# model version aliases; check which your MLflow version supports.

# Transition to staging (after validation)
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Staging",
    archive_existing_versions=False,
)

# Promote to production (after approval)
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Production",
    archive_existing_versions=True,  # Archive old production version
)

# Load production model (new_data is assumed to be a pandas DataFrame)
model = mlflow.pyfunc.load_model("models:/churn-prediction-model/Production")
predictions = model.predict(new_data)
```
### 4. Model Validation Pipeline
```python
# validation/validate_model.py
import mlflow
import pandas as pd
from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient
from sklearn.metrics import accuracy_score, roc_auc_score


def validate_model(model_name: str, version: str) -> bool:
    """Validate model before promotion"""
    client = MlflowClient()
    model_uri = f"models:/{model_name}/{version}"

    # Load model
    model = mlflow.pyfunc.load_model(model_uri)

    # Load validation dataset
    val_data = pd.read_parquet("s3://data/validation.parquet")
    X_val, y_val = val_data.drop("target", axis=1), val_data["target"]

    # Run predictions
    predictions = model.predict(X_val)

    # Calculate metrics
    metrics = {
        "val_accuracy": accuracy_score(y_val, predictions["churn_prediction"]),
        "val_auc": roc_auc_score(y_val, predictions["churn_probability"]),
    }

    # Get production model metrics (if exists)
    try:
        prod_model = mlflow.pyfunc.load_model(f"models:/{model_name}/Production")
        prod_predictions = prod_model.predict(X_val)
        prod_metrics = {
            "prod_accuracy": accuracy_score(y_val, prod_predictions["churn_prediction"]),
            "prod_auc": roc_auc_score(y_val, prod_predictions["churn_probability"]),
        }
    except MlflowException:
        # No Production version yet: any candidate beats the zero baseline
        prod_metrics = {"prod_accuracy": 0, "prod_auc": 0}

    # Validation rules
    validations = [
        ("accuracy_threshold", metrics["val_accuracy"] >= 0.85),
        ("auc_threshold", metrics["val_auc"] >= 0.80),
        ("accuracy_improvement", metrics["val_accuracy"] >= prod_metrics["prod_accuracy"]),
        # Allow an absolute AUC drop of up to 0.01
        ("auc_improvement", metrics["val_auc"] >= prod_metrics["prod_auc"] - 0.01),
    ]

    # Log validation results
    with mlflow.start_run(run_name=f"validation-{model_name}-v{version}"):
        mlflow.log_metrics(metrics)
        mlflow.log_metrics(prod_metrics)
        for name, passed in validations:
            mlflow.log_metric(f"validation_{name}", int(passed))

    # Update model tags
    all_passed = all(passed for _, passed in validations)
    client.set_model_version_tag(
        name=model_name,
        version=version,
        key="validation_status",
        value="passed" if all_passed else "failed",
    )
    return all_passed
```
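The promotion rules in `validate_model` are easier to unit-test when they are separated from model loading and logging. The sketch below applies the same thresholds (0.85 accuracy, 0.80 AUC, 0.01 AUC tolerance) to metrics that have already been computed. `evaluate_gates` and its metric keys are hypothetical names chosen for illustration.

```python
from typing import Dict, Optional


def evaluate_gates(
    candidate: Dict[str, float],
    production: Optional[Dict[str, float]] = None,
    min_accuracy: float = 0.85,
    min_auc: float = 0.80,
    auc_tolerance: float = 0.01,
) -> Dict[str, bool]:
    """Apply the promotion rules to already-computed metrics.

    `production=None` means there is no Production model yet, so the
    comparison checks pass trivially.
    """
    checks = {
        "accuracy_threshold": candidate["accuracy"] >= min_accuracy,
        "auc_threshold": candidate["auc"] >= min_auc,
    }
    if production is None:
        checks["accuracy_improvement"] = True
        checks["auc_improvement"] = True
    else:
        checks["accuracy_improvement"] = (
            candidate["accuracy"] >= production["accuracy"]
        )
        # The candidate may trail Production by at most `auc_tolerance`
        checks["auc_improvement"] = (
            candidate["auc"] >= production["auc"] - auc_tolerance
        )
    return checks
```

A caller would promote only when `all(evaluate_gates(...).values())` is true, and could log each entry as a 0/1 metric in the same way the function above does.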
### 5. Model Serving
```python
# serve/model_server.py
import mlflow
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Load model at startup
MODEL_NAME = "churn-prediction-model"
MODEL_STAGE = "Production"
model = None


@app.on_event("startup")
async def load_model():
    global model
    model = mlflow.pyfunc.load_model(f"models:/{MODEL_NAME}/{MODEL_STAGE}")


class PredictionRequest(BaseModel):
    features: dict


class PredictionResponse(BaseModel):
    churn_probability: float
    churn_prediction: int
    model_version: str


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        input_df = pd.DataFrame([request.features])
        predictions = model.predict(input_df)
        return PredictionResponse(
            churn_probability=float(predictions["churn_probability"].iloc[0]),
            churn_prediction=int(predictions["churn_prediction"].iloc[0]),
            # The run ID identifies the model here; it is not the registry version number
            model_version=model.metadata.run_id,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": model is not None}


# Or use MLflow's built-in serving
# mlflow models serve -m "models:/churn-prediction-model/Production" -p 5001
```
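On the client side, the `/predict` endpoint above expects a JSON body of the form `{"features": {...}}`. A minimal standard-library sketch of building that request follows. The base URL (port 8000) and the feature name are assumptions for illustration, and `build_predict_request` is a hypothetical helper.

```python
import json
import urllib.request


def build_predict_request(base_url: str, features: dict) -> urllib.request.Request:
    """Build (but do not send) a POST request for the /predict endpoint."""
    body = json.dumps({"features": features}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Assumed address of the FastAPI app; adjust to your deployment
req = build_predict_request("http://localhost:8000", {"tenure": 12})
```

Passing `req` to `urllib.request.urlopen` sends it. The response body should then carry the `churn_probability`, `churn_prediction`, and `model_version` fields declared in `PredictionResponse`.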
## Quick Start
1. **Install MLflow:**
```bash
pip install mlflow
```
2. **Start tracking server:**
```bash
mlflow server --backend-store-uri sqlite:///mlflow.db \
--default-artifact-root s3://mlflow-artifacts \
--host 0.0.0.0
```
3. **Set tracking URI in code:**
```python
import mlflow

mlflow.set_tracking_uri("http://localhost:5000")
```
4. **Run experiment:**
```python
with mlflow.start_run():
    mlflow.log_param("param", value)
    mlflow.log_metric("metric", value)
    mlflow.sklearn.log_model(model, "model")
```
5. **View in UI:** Open http://localhost:5000
## Production Checklist
- [ ] Tracking server with persistent backend
- [ ] Artifact storage (S3/GCS/Azure Blob)
- [ ] Authentication enabled
- [ ] Model signature defined
- [ ] Input examples logged
- [ ] Conda/pip environment specified
- [ ] Validation pipeline configured
- [ ] Model approval workflow
- [ ] Monitoring for model drift
## Anti-patterns
1. **No Experiment Naming**: Use meaningful experiment/run names
2. **Skipping Signatures**: Always define model signatures
3. **Manual Promotion**: Use validation pipeline for stage transitions
4. **Missing Environment**: Always specify dependencies
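One way to avoid the manual-promotion anti-pattern is to make the stage transition conditional on the `validation_status` tag that the validation pipeline writes. The sketch below takes the client as a parameter so it can be exercised with a stub. `promote_if_validated` is a hypothetical helper, and `client` is assumed to behave like `mlflow.tracking.MlflowClient`.

```python
def promote_if_validated(client, name: str, version: str) -> bool:
    """Move a version to Production only if its validation tag is "passed".

    Returns True if the transition was requested, False otherwise.
    """
    tags = client.get_model_version(name=name, version=version).tags
    if tags.get("validation_status") != "passed":
        return False
    client.transition_model_version_stage(
        name=name,
        version=version,
        stage="Production",
        archive_existing_versions=True,  # keep a single Production version
    )
    return True
```

Running this from a CI job or orchestrator task after `validate_model` means no version reaches Production without a recorded validation result.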
## Integration Points
- **Storage**: S3, GCS, Azure Blob, HDFS
- **Databases**: PostgreSQL, MySQL for backend store
- **Orchestration**: Airflow, Prefect, Dagster
- **Serving**: SageMaker, Kubernetes, Azure ML
## Further Reading
- [MLflow Documentation](https://mlflow.org/docs/latest/index.html)
- [MLflow Model Registry](https://mlflow.org/docs/latest/model-registry.html)
- [MLflow Recipes](https://mlflow.org/docs/latest/recipes.html)
This skill shows patterns for using MLflow to track experiments, register models, validate versions, and serve production-ready models. It focuses on reproducible ML workflows: logging runs, packaging custom models, enforcing validation rules, and promoting models through staging to production. The content is practical and targeted at teams building governed model delivery pipelines.
The skill inspects common MLflow flows and provides concrete code patterns: experiment tracking with parameters, metrics, artifacts, and datasets; custom pyfunc model wrappers for preprocessing and prediction; model registry operations for versioning and stage transitions; automated validation pipelines that compare candidate vs production; and serving examples using FastAPI or MLflow’s built-in server. Each pattern logs metadata and artifacts to enable reproducibility and traceability.
## FAQ
### How do I handle missing features at inference?
Use a pyfunc wrapper that validates input columns and either raises a clear error or applies default imputation before prediction.
### When should I archive previous production versions?
Archive older production versions when promoting a new production model. This keeps a single active Production version while preserving history for rollbacks.
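Because archived versions stay in the registry, a rollback can be sketched as restoring the newest archived version to Production. This assumes that only former Production versions are ever archived, which will not hold if rejected candidates are archived too. `rollback_production` is a hypothetical helper, and `client` is assumed to behave like `mlflow.tracking.MlflowClient`.

```python
from typing import Optional


def rollback_production(client, name: str) -> Optional[str]:
    """Restore the newest archived version to Production.

    Returns the restored version, or None if nothing is archived.
    """
    versions = client.search_model_versions(f"name='{name}'")
    archived = [v for v in versions if v.current_stage == "Archived"]
    if not archived:
        return None
    # Version numbers are strings, so compare them numerically
    target = max(archived, key=lambda v: int(v.version))
    client.transition_model_version_stage(
        name=name,
        version=target.version,
        stage="Production",
        archive_existing_versions=True,  # archives the version being replaced
    )
    return target.version
```

Recording the previous Production version in a tag at promotion time would make the rollback target explicit instead of inferred.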