Agent skill
MLflow Patterns
ML experiment tracking, model registry, and deployment with MLflow for reproducible machine learning workflows.
Stars
163
Forks
31
Install this agent skill to your Project
npx add-skill https://github.com/majiayu000/claude-skill-registry/tree/main/skills/data/mlflow-patterns
SKILL.md
MLflow Patterns
Overview
MLflow เป็น open-source platform สำหรับ managing ML lifecycle ครอบคลุม experiment tracking, model packaging, model registry, และ deployment ช่วยให้ทีม data science ทำงานร่วมกันและ deploy models ได้อย่าง reproducible
Why This Matters
- Reproducibility: Track experiments และ reproduce results
- Collaboration: Share experiments และ models across team
- Deployment: Package และ deploy models consistently
- Governance: Model versioning และ approval workflow
Core Concepts
1. Experiment Tracking
python
import mlflow
from mlflow.tracking import MlflowClient
# Point the client at the tracking server and select the experiment.
mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("customer-churn-prediction")

# Turn on scikit-learn auto-logging before the run is opened.
mlflow.sklearn.autolog()

# NOTE(review): XGBClassifier, the sklearn metric functions and the
# X_train / X_test / y_train / y_test splits are not defined in this snippet;
# they are assumed to be imported / prepared by the surrounding code.
with mlflow.start_run(run_name="xgboost-v1") as run:
    # Declare the hyperparameters once so the logged values and the values
    # used for training cannot drift apart.
    hyperparams = {
        "learning_rate": 0.1,
        "max_depth": 6,
        "n_estimators": 100,
        "subsample": 0.8,
    }
    mlflow.log_params(hyperparams)

    # Train the classifier with exactly the logged hyperparameters.
    model = XGBClassifier(**hyperparams)
    model.fit(X_train, y_train)

    # Evaluate on the held-out split: hard labels for the threshold metrics,
    # positive-class probabilities for ROC AUC.
    y_pred = model.predict(X_test)
    positive_scores = model.predict_proba(X_test)[:, 1]
    mlflow.log_metrics({
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred),
        "recall": recall_score(y_test, y_pred),
        "f1": f1_score(y_test, y_pred),
        "auc_roc": roc_auc_score(y_test, positive_scores),
    })

    # Attach the diagnostic plots (expected to exist in the working directory).
    for plot_path in ("feature_importance.png", "confusion_matrix.png"):
        mlflow.log_artifact(plot_path)

    # Store the fitted model and register it in the model registry.
    mlflow.sklearn.log_model(
        model,
        artifact_path="model",
        registered_model_name="churn-prediction-model",
    )

    # Record which dataset the run was trained on.
    training_dataset = mlflow.data.from_pandas(
        X_train, source="s3://data/train.parquet"
    )
    mlflow.log_input(training_dataset, context="training")

    print(f"Run ID: {run.info.run_id}")
2. Custom Model Wrapper
python
import mlflow.pyfunc
class ChurnModelWrapper(mlflow.pyfunc.PythonModel):
    """Custom pyfunc model that preprocesses input before scoring churn.

    Expects three artifacts to be logged with the model: ``model`` and
    ``preprocessor`` (joblib files) and ``feature_names`` (a JSON file).
    """

    def load_context(self, context):
        """Load the model, the preprocessor and the feature list from artifacts.

        Args:
            context: MLflow model context; ``context.artifacts`` maps each
                artifact name to a local file path.
        """
        import json

        import joblib

        self.model = joblib.load(context.artifacts["model"])
        self.preprocessor = joblib.load(context.artifacts["preprocessor"])
        # ``context.artifacts`` holds file paths, not file contents, so the
        # feature list has to be read from disk. Assigning the path directly
        # would make the validation below iterate over its characters.
        # NOTE(review): assumes features.json contains a JSON array of column
        # names -- confirm against the code that writes it.
        with open(context.artifacts["feature_names"], encoding="utf-8") as fh:
            self.feature_names = json.load(fh)

    def predict(self, context, model_input):
        """Preprocess ``model_input`` and return churn scores.

        Args:
            context: MLflow model context (unused here).
            model_input: DataFrame that must contain every column listed in
                ``self.feature_names``.

        Returns:
            DataFrame with ``churn_probability`` (positive-class probability)
            and ``churn_prediction`` (1 when the probability exceeds 0.5).

        Raises:
            ValueError: If any required feature column is missing.
        """
        import pandas as pd

        # Validate input and report exactly which columns are absent.
        missing = [
            col for col in self.feature_names if col not in model_input.columns
        ]
        if missing:
            raise ValueError(f"Missing required features: {missing}")

        # Select the columns in the stored order, then preprocess.
        processed = self.preprocessor.transform(model_input[self.feature_names])

        # Probability of the positive class.
        predictions = self.model.predict_proba(processed)[:, 1]
        return pd.DataFrame({
            "churn_probability": predictions,
            "churn_prediction": (predictions > 0.5).astype(int),
        })
# Log custom model
with mlflow.start_run():
    # Local files that are copied into the model directory and later exposed
    # to ChurnModelWrapper.load_context() through ``context.artifacts``.
    artifacts = {
        "model": "model.joblib",
        "preprocessor": "preprocessor.joblib",
        "feature_names": "features.json",
    }

    # A complete environment spec: named, with an explicit channel, and with
    # the runtime packages pinned in the pip section. ``mlflow`` itself must
    # be included, otherwise the restored environment cannot load or serve
    # the model.
    conda_env = {
        "name": "churn-model-env",
        "channels": ["conda-forge"],
        "dependencies": [
            "python=3.10",
            "pip",
            {
                "pip": [
                    "mlflow",
                    "scikit-learn==1.3.0",
                    "xgboost==2.0.0",
                    "pandas==2.0.0",
                ],
            },
        ],
    }

    # NOTE(review): ``X_test`` and ``predictions`` are not defined in this
    # snippet. ``predictions`` should be a sample of the wrapper's output for
    # ``X_test`` so the inferred signature matches what predict() returns --
    # confirm where it is produced.
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=ChurnModelWrapper(),
        artifacts=artifacts,
        conda_env=conda_env,
        signature=mlflow.models.infer_signature(X_test, predictions),
        input_example=X_test.head(5),
    )
3. Model Registry
python
from mlflow.tracking import MlflowClient
client = MlflowClient()

# Single source for the registered model name used by every call below.
registered_name = "churn-prediction-model"

# Register the model logged under the given run.
# NOTE(review): ``run_id`` is not defined in this snippet; it is assumed to
# come from an earlier tracking run.
model_uri = f"runs:/{run_id}/model"
model_version = mlflow.register_model(model_uri, registered_name)
version = model_version.version

# Describe the new version and mark it as awaiting validation.
client.update_model_version(
    name=registered_name,
    version=version,
    description="XGBoost model trained on Q4 2024 data"
)
client.set_model_version_tag(
    name=registered_name,
    version=version,
    key="validation_status",
    value="pending"
)

# Move the version through the stages in order: Staging (after validation),
# then Production (after approval). Only the Production transition archives
# the versions that previously held the stage.
# NOTE(review): stage transitions are reported as deprecated in recent MLflow
# releases in favour of aliases -- confirm against the MLflow version in use.
for target_stage, archive_existing in (("Staging", False), ("Production", True)):
    client.transition_model_version_stage(
        name=registered_name,
        version=version,
        stage=target_stage,
        archive_existing_versions=archive_existing
    )

# Load whichever version currently holds the Production stage and score.
# NOTE(review): ``new_data`` is assumed to be defined by the caller.
model = mlflow.pyfunc.load_model(f"models:/{registered_name}/Production")
predictions = model.predict(new_data)
4. Model Validation Pipeline
python
# validation/validate_model.py
import mlflow
from mlflow.tracking import MlflowClient
def validate_model(model_name: str, version: str) -> bool:
    """Validate a registered model version before promotion.

    Scores the candidate on the validation dataset, compares it with the
    current Production model when one can be loaded, logs the results to a
    new MLflow run and tags the model version with ``validation_status``.

    NOTE(review): ``pd``, ``accuracy_score`` and ``roc_auc_score`` are assumed
    to be imported at module level; the snippet does not show those imports.

    Args:
        model_name: Name of the registered model.
        version: Version of the registered model to validate.

    Returns:
        True if every validation rule passed, otherwise False.
    """
    import logging

    client = MlflowClient()

    # Load the candidate version.
    candidate = mlflow.pyfunc.load_model(f"models:/{model_name}/{version}")

    # Load validation dataset.
    val_data = pd.read_parquet("s3://data/validation.parquet")
    X_val, y_val = val_data.drop("target", axis=1), val_data["target"]

    # Score the candidate.
    predictions = candidate.predict(X_val)
    metrics = {
        "val_accuracy": accuracy_score(y_val, predictions["churn_prediction"]),
        "val_auc": roc_auc_score(y_val, predictions["churn_probability"]),
    }

    # Baseline used when no Production model is available.
    prod_metrics = {"prod_accuracy": 0, "prod_auc": 0}

    # Only the load is guarded. A Production model that loads but fails to
    # predict is a real error and must not silently become a zero baseline,
    # because that would let any candidate pass the improvement checks.
    try:
        prod_model = mlflow.pyfunc.load_model(f"models:/{model_name}/Production")
    except mlflow.exceptions.MlflowException as exc:
        logging.getLogger(__name__).warning(
            "No Production model could be loaded for %s; "
            "using a zero baseline: %s",
            model_name,
            exc,
        )
        prod_model = None

    if prod_model is not None:
        prod_predictions = prod_model.predict(X_val)
        prod_metrics = {
            "prod_accuracy": accuracy_score(y_val, prod_predictions["churn_prediction"]),
            "prod_auc": roc_auc_score(y_val, prod_predictions["churn_probability"]),
        }

    # Validation rules: absolute thresholds plus comparison with Production.
    validations = [
        ("accuracy_threshold", metrics["val_accuracy"] >= 0.85),
        ("auc_threshold", metrics["val_auc"] >= 0.80),
        ("accuracy_improvement", metrics["val_accuracy"] >= prod_metrics["prod_accuracy"]),
        # AUC may drop by at most 0.01 relative to Production.
        ("auc_improvement", metrics["val_auc"] >= prod_metrics["prod_auc"] - 0.01),
    ]

    # Log validation results; each rule is stored as 1 (passed) or 0 (failed).
    with mlflow.start_run(run_name=f"validation-{model_name}-v{version}"):
        mlflow.log_metrics(metrics)
        mlflow.log_metrics(prod_metrics)
        for name, passed in validations:
            mlflow.log_metric(f"validation_{name}", int(passed))

    # Record the overall outcome on the model version.
    all_passed = all(passed for _, passed in validations)
    client.set_model_version_tag(
        name=model_name,
        version=version,
        key="validation_status",
        value="passed" if all_passed else "failed"
    )
    return all_passed
5. Model Serving
python
# serve/model_server.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import mlflow
# FastAPI application that exposes the prediction and health endpoints.
app = FastAPI()
# Load model at startup
# Registry coordinates of the model to serve: registered name and stage.
MODEL_NAME = "churn-prediction-model"
MODEL_STAGE = "Production"
# Placeholder for the loaded model; stays None until the startup hook
# (load_model) assigns it.
model = None
@app.on_event("startup")
async def load_model():
    """Startup hook: load the configured registry model into ``model``."""
    global model
    # Registry URI built from the module-level name and stage constants.
    registry_uri = f"models:/{MODEL_NAME}/{MODEL_STAGE}"
    model = mlflow.pyfunc.load_model(registry_uri)
class PredictionRequest(BaseModel):
    # Request body for POST /predict.
    # ``features`` carries the values for a single record; the handler wraps
    # this dict in a one-row DataFrame before scoring.
    features: dict
class PredictionResponse(BaseModel):
    # Response body for POST /predict.
    # Score taken from the model output column "churn_probability".
    churn_probability: float
    # Label taken from the model output column "churn_prediction".
    churn_prediction: int
    # Identifier of the serving model; the handler fills this from the loaded
    # model's metadata.
    # NOTE(review): pydantic v2 reportedly warns about field names starting
    # with "model_" -- confirm against the pydantic version in use.
    model_version: str
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Score a single record with the loaded model.

    Args:
        request: Request body whose ``features`` dict holds one record.

    Returns:
        PredictionResponse with the churn probability, the churn label and
        the identifier of the serving model.

    Raises:
        HTTPException: 503 if the model has not been loaded yet; 400 if
            building the input or running the prediction fails.
    """
    import pandas as pd

    # Without this guard a missing model raises AttributeError below and is
    # reported to the client as a 400, although the request is not at fault.
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    try:
        # One request maps to a one-row DataFrame.
        input_df = pd.DataFrame([request.features])
        predictions = model.predict(input_df)
        return PredictionResponse(
            churn_probability=float(predictions["churn_probability"].iloc[0]),
            churn_prediction=int(predictions["churn_prediction"].iloc[0]),
            # NOTE(review): this is the training run id, not the registry
            # version number -- confirm this is what clients expect.
            model_version=model.metadata.run_id,
        )
    except Exception as e:
        # Keep the original cause attached for server-side tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.get("/health")
async def health():
    """Report service status and whether the model has been loaded."""
    # ``model`` is None until the startup hook has assigned it.
    is_loaded = model is not None
    return {"status": "healthy", "model_loaded": is_loaded}
# Or use MLflow's built-in serving
# mlflow models serve -m "models:/churn-prediction-model/Production" -p 5001
Quick Start
- Install MLflow:
bash
pip install mlflow
- Start tracking server:
bash
mlflow server --backend-store-uri sqlite:///mlflow.db \
    --default-artifact-root s3://mlflow-artifacts \
    --host 0.0.0.0
- Set tracking URI in code:
python
mlflow.set_tracking_uri("http://localhost:5000")
- Run experiment:
python
with mlflow.start_run():
    mlflow.log_param("param", value)
    mlflow.log_metric("metric", value)
    mlflow.sklearn.log_model(model, "model")
- View in UI: Open http://localhost:5000
Production Checklist
- Tracking server with persistent backend
- Artifact storage (S3/GCS/Azure Blob)
- Authentication enabled
- Model signature defined
- Input examples logged
- Conda/pip environment specified
- Validation pipeline configured
- Model approval workflow
- Monitoring for model drift
Anti-patterns
- No Experiment Naming: Use meaningful experiment/run names
- Skipping Signatures: Always define model signatures
- Manual Promotion: Use validation pipeline for stage transitions
- Missing Environment: Always specify dependencies
Integration Points
- Storage: S3, GCS, Azure Blob, HDFS
- Databases: PostgreSQL, MySQL for backend store
- Orchestration: Airflow, Prefect, Dagster
- Serving: SageMaker, Kubernetes, Azure ML
Further Reading
Didn't find the tool you were looking for?