Agent skill

MLflow Patterns

ML experiment tracking, model registry, and deployment with MLflow for reproducible machine learning workflows.

Install this agent skill to your Project

npx add-skill https://github.com/majiayu000/claude-skill-registry/tree/main/skills/data/mlflow-patterns

Overview

MLflow is an open-source platform for managing the ML lifecycle, covering experiment tracking, model packaging, model registry, and deployment. It helps data science teams collaborate and deploy models reproducibly.

Why This Matters

  • Reproducibility: Track experiments and reproduce results
  • Collaboration: Share experiments and models across the team
  • Deployment: Package and deploy models consistently
  • Governance: Model versioning and approval workflows

Core Concepts

1. Experiment Tracking

python
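import mlflow
import mlflow.sklearn
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from xgboost import XGBClassifier

# X_train, X_test, y_train, y_test are assumed to be prepared beforehand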

# Set tracking URI
mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("customer-churn-prediction")

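# Enable auto-logging before starting the run.
# mlflow.sklearn.autolog() targets scikit-learn estimators;
# mlflow.xgboost.autolog() is the XGBoost counterpart.
mlflow.sklearn.autolog()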

with mlflow.start_run(run_name="xgboost-v1") as run:
    # Log parameters
    mlflow.log_params({
        "learning_rate": 0.1,
        "max_depth": 6,
        "n_estimators": 100,
        "subsample": 0.8,
    })
    
    # Train model
    model = XGBClassifier(
        learning_rate=0.1,
        max_depth=6,
        n_estimators=100,
        subsample=0.8,
    )
    model.fit(X_train, y_train)
    
    # Log metrics
    y_pred = model.predict(X_test)
    mlflow.log_metrics({
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred),
        "recall": recall_score(y_test, y_pred),
        "f1": f1_score(y_test, y_pred),
        "auc_roc": roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]),
    })
    
    # Log artifacts
    mlflow.log_artifact("feature_importance.png")
    mlflow.log_artifact("confusion_matrix.png")
    
    # Log model
    mlflow.sklearn.log_model(
        model,
        artifact_path="model",
        registered_model_name="churn-prediction-model",
    )
    
    # Log dataset info
    mlflow.log_input(
        mlflow.data.from_pandas(X_train, source="s3://data/train.parquet"),
        context="training"
    )
    
    print(f"Run ID: {run.info.run_id}")
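
The run above writes each hyperparameter twice, once for `log_params` and once for the model constructor, so the logged values and the trained values can drift apart. A minimal sketch of driving both from one dictionary follows. The names `train_with_logged_params`, `log_params`, and `model_factory` are hypothetical; in the run above they would correspond to `mlflow.log_params` and `XGBClassifier`, and the stand-ins below exist only so the sketch runs without MLflow or XGBoost installed.

```python
from typing import Any, Callable, Dict


def train_with_logged_params(
    params: Dict[str, Any],
    log_params: Callable[[Dict[str, Any]], Any],
    model_factory: Callable[..., Any],
) -> Any:
    """Log `params` and build the model from the same dictionary."""
    log_params(params)              # e.g. mlflow.log_params
    return model_factory(**params)  # e.g. XGBClassifier


# Stand-in for the tracking call: records whatever would be logged.
logged: Dict[str, Any] = {}


def fake_model(**kwargs: Any) -> Dict[str, Any]:
    """Stand-in for a model constructor: returns the arguments it received."""
    return dict(kwargs)


hyperparams = {
    "learning_rate": 0.1,
    "max_depth": 6,
    "n_estimators": 100,
    "subsample": 0.8,
}
model_args = train_with_logged_params(hyperparams, logged.update, fake_model)
```

Because both calls read the same dictionary, changing a value in `hyperparams` changes what is logged and what is trained together.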

2. Custom Model Wrapper

python
import mlflow
import mlflow.pyfunc
import pandas as pd
from mlflow.models import infer_signature

class ChurnModelWrapper(mlflow.pyfunc.PythonModel):
    """Custom model wrapper with preprocessing"""
    
    def load_context(self, context):
        """Load model and artifacts"""
        import json
        import joblib
        # context.artifacts maps each artifact name to a local file path
        self.model = joblib.load(context.artifacts["model"])
        self.preprocessor = joblib.load(context.artifacts["preprocessor"])
        # Assumes features.json holds a JSON array of column names
        with open(context.artifacts["feature_names"]) as f:
            self.feature_names = json.load(f)
    
    def predict(self, context, model_input):
        """Predict with preprocessing"""
        # Validate input and report only the columns that are missing
        missing = [col for col in self.feature_names if col not in model_input.columns]
        if missing:
            raise ValueError(f"Missing required features: {missing}")
        
        # Preprocess
        processed = self.preprocessor.transform(model_input[self.feature_names])
        
        # Predict with probability
        predictions = self.model.predict_proba(processed)[:, 1]
        
        return pd.DataFrame({
            "churn_probability": predictions,
            "churn_prediction": (predictions > 0.5).astype(int),
        })

# Example output with the wrapper's schema, used only to infer the signature
example_output = pd.DataFrame({
    "churn_probability": [0.5],
    "churn_prediction": [0],
})

# Log custom model (X_test is assumed to be the held-out feature DataFrame)
with mlflow.start_run():
    artifacts = {
        "model": "model.joblib",
        "preprocessor": "preprocessor.joblib",
        "feature_names": "features.json",
    }
    
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=ChurnModelWrapper(),
        artifacts=artifacts,
        conda_env={
            "name": "churn-model-env",
            "channels": ["conda-forge"],
            "dependencies": [
                "python=3.10",
                "scikit-learn=1.3.0",
                "xgboost=2.0.0",
                "pandas=2.0.0",
                "pip",
                {"pip": ["mlflow"]},
            ],
        },
        signature=infer_signature(X_test, example_output),
        input_example=X_test.head(5),
    )
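
The wrapper receives each artifact as a local file path through `context.artifacts`, so the feature list has to be read from `features.json` rather than used directly. A minimal sketch of writing and reading that file follows, assuming it holds a plain JSON array of column names. The file layout, the helper names, and the example feature names are assumptions for illustration, not something the skill specifies.

```python
import json
import os
import tempfile
from typing import Iterable, List


def save_feature_names(path: str, feature_names: List[str]) -> None:
    """Write the ordered feature list as a JSON array."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(feature_names, f)


def load_feature_names(path: str) -> List[str]:
    """Read the feature list back, rejecting anything that is not a list of strings."""
    with open(path, encoding="utf-8") as f:
        names = json.load(f)
    if not isinstance(names, list) or not all(isinstance(n, str) for n in names):
        raise ValueError(f"{path} must contain a JSON array of strings")
    return names


def missing_features(required: List[str], available: Iterable[str]) -> List[str]:
    """Return the required names that are absent from `available`, in order."""
    available_set = set(available)
    return [name for name in required if name not in available_set]


# Round-trip through a temporary file (hypothetical feature names).
with tempfile.TemporaryDirectory() as tmp:
    features_path = os.path.join(tmp, "features.json")
    save_feature_names(features_path, ["tenure", "monthly_charges", "contract_type"])
    loaded = load_feature_names(features_path)

# Compare against the columns of an incoming request.
missing = missing_features(loaded, ["tenure", "contract_type", "extra_column"])
```

Reporting only the missing names keeps the error message short when the model expects many features.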

3. Model Registry

python
import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register model from run
# (run_id is the ID of the training run, e.g. run.info.run_id)
model_uri = f"runs:/{run_id}/model"
model_version = mlflow.register_model(model_uri, "churn-prediction-model")

# Add description and tags
client.update_model_version(
    name="churn-prediction-model",
    version=model_version.version,
    description="XGBoost model trained on Q4 2024 data"
)

client.set_model_version_tag(
    name="churn-prediction-model",
    version=model_version.version,
    key="validation_status",
    value="pending"
)

# Transition to staging (after validation)
# Note: model stages are deprecated since MLflow 2.9 in favor of aliases and tags
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Staging",
    archive_existing_versions=False
)

# Promote to production (after approval)
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Production",
    archive_existing_versions=True  # Archive old production version
)

# Load production model
model = mlflow.pyfunc.load_model("models:/churn-prediction-model/Production")
predictions = model.predict(new_data)
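
Model stages are deprecated as of MLflow 2.9 in favor of model version aliases. A minimal sketch of the same promotion step with an alias follows. It assumes a client object exposing `set_registered_model_alias(name, alias, version)`, as `MlflowClient` does in recent MLflow 2.x releases. The `champion` alias, the helper names, and the recording stub are illustrative only; the stub stands in for a real client so the sketch runs without a tracking server.

```python
from typing import Dict, Tuple


class RecordingClient:
    """Stand-in for MlflowClient that only records alias assignments."""

    def __init__(self) -> None:
        self.aliases: Dict[Tuple[str, str], str] = {}

    def set_registered_model_alias(self, name: str, alias: str, version: str) -> None:
        self.aliases[(name, alias)] = str(version)


def alias_uri(name: str, alias: str) -> str:
    """Build a model URI that resolves through an alias."""
    return f"models:/{name}@{alias}"


def promote(client, name: str, version: str, alias: str = "champion") -> str:
    """Point `alias` at `version` and return the URI consumers should load."""
    client.set_registered_model_alias(name, alias, version)
    return alias_uri(name, alias)


client = RecordingClient()
uri = promote(client, "churn-prediction-model", "3")
```

With a real `MlflowClient`, passing `uri` to `mlflow.pyfunc.load_model` would load whichever version the alias currently points to, so consumers do not need to change when a new version is promoted.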

4. Model Validation Pipeline

python
# validation/validate_model.py
import mlflow
import pandas as pd
from mlflow.tracking import MlflowClient
from sklearn.metrics import accuracy_score, roc_auc_score

def validate_model(model_name: str, version: str) -> bool:
    """Validate model before promotion"""
    
    client = MlflowClient()
    model_uri = f"models:/{model_name}/{version}"
    
    # Load model
    model = mlflow.pyfunc.load_model(model_uri)
    
    # Load validation dataset
    val_data = pd.read_parquet("s3://data/validation.parquet")
    X_val, y_val = val_data.drop("target", axis=1), val_data["target"]
    
    # Run predictions
    predictions = model.predict(X_val)
    
    # Calculate metrics
    metrics = {
        "val_accuracy": accuracy_score(y_val, predictions["churn_prediction"]),
        "val_auc": roc_auc_score(y_val, predictions["churn_probability"]),
    }
    
    # Get production model metrics (if exists)
    try:
        prod_model = mlflow.pyfunc.load_model(f"models:/{model_name}/Production")
        prod_predictions = prod_model.predict(X_val)
        prod_metrics = {
            "prod_accuracy": accuracy_score(y_val, prod_predictions["churn_prediction"]),
            "prod_auc": roc_auc_score(y_val, prod_predictions["churn_probability"]),
        }
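    except mlflow.exceptions.MlflowException:
        # No Production version to compare against (or it could not be resolved)
        prod_metrics = {"prod_accuracy": 0, "prod_auc": 0}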
    
    # Validation rules
    validations = [
        ("accuracy_threshold", metrics["val_accuracy"] >= 0.85),
        ("auc_threshold", metrics["val_auc"] >= 0.80),
        ("accuracy_improvement", metrics["val_accuracy"] >= prod_metrics["prod_accuracy"]),
        ("auc_improvement", metrics["val_auc"] >= prod_metrics["prod_auc"] - 0.01),  # Allow a 0.01 absolute drop in AUC
    ]
    
    # Log validation results
    with mlflow.start_run(run_name=f"validation-{model_name}-v{version}"):
        mlflow.log_metrics(metrics)
        mlflow.log_metrics(prod_metrics)
        
        for name, passed in validations:
            mlflow.log_metric(f"validation_{name}", int(passed))
    
    # Update model tags
    all_passed = all(passed for _, passed in validations)
    client.set_model_version_tag(
        name=model_name,
        version=version,
        key="validation_status",
        value="passed" if all_passed else "failed"
    )
    
    return all_passed
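
The validation rules are easier to unit-test when they are separated from model loading and logging. A minimal sketch of that separation follows, using the same thresholds as the function above. The function name `evaluate_gates` and its keyword parameters are hypothetical, and skipping the comparison gates when no production metrics exist is a design choice of this sketch rather than the behavior of `validate_model`.

```python
from typing import Dict, Optional


def evaluate_gates(
    candidate: Dict[str, float],
    production: Optional[Dict[str, float]] = None,
    min_accuracy: float = 0.85,
    min_auc: float = 0.80,
    auc_tolerance: float = 0.01,
) -> Dict[str, bool]:
    """Return one pass/fail flag per validation rule."""
    gates = {
        "accuracy_threshold": candidate["val_accuracy"] >= min_accuracy,
        "auc_threshold": candidate["val_auc"] >= min_auc,
    }
    # Comparison gates apply only when a production baseline exists.
    if production is not None:
        gates["accuracy_improvement"] = (
            candidate["val_accuracy"] >= production["prod_accuracy"]
        )
        gates["auc_improvement"] = (
            candidate["val_auc"] >= production["prod_auc"] - auc_tolerance
        )
    return gates


candidate_metrics = {"val_accuracy": 0.90, "val_auc": 0.88}

# First deployment: no production model yet.
first = evaluate_gates(candidate_metrics)

# Later deployment: production is slightly more accurate than the candidate.
later = evaluate_gates(
    candidate_metrics,
    production={"prod_accuracy": 0.92, "prod_auc": 0.885},
)
```

Promotion would then proceed only when `all(gates.values())` is true, and each flag can still be logged individually as in the pipeline above.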

5. Model Serving

python
# serve/model_server.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import mlflow
import pandas as pd

app = FastAPI()

# Load model at startup
MODEL_NAME = "churn-prediction-model"
MODEL_STAGE = "Production"
model = None

@app.on_event("startup")
async def load_model():
    global model
    model = mlflow.pyfunc.load_model(f"models:/{MODEL_NAME}/{MODEL_STAGE}")

class PredictionRequest(BaseModel):
    features: dict

class PredictionResponse(BaseModel):
    churn_probability: float
    churn_prediction: int
    model_version: str

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        input_df = pd.DataFrame([request.features])
        predictions = model.predict(input_df)
        
        return PredictionResponse(
            churn_probability=float(predictions["churn_probability"].iloc[0]),
            churn_prediction=int(predictions["churn_prediction"].iloc[0]),
            # run_id identifies the training run behind the loaded model,
            # not the registry version number
            model_version=model.metadata.run_id,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": model is not None}

# Or use MLflow's built-in serving
# mlflow models serve -m "models:/churn-prediction-model/Production" -p 5001
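
The built-in server started by `mlflow models serve` accepts JSON requests at `/invocations`. In MLflow 2.x one accepted request shape is a `dataframe_split` object with `columns` and `data`. A minimal sketch of building that payload with the standard library follows; the helper name and the feature names are hypothetical, and sending the request is left out so the sketch runs without a server.

```python
import json
from typing import Any, Dict, List


def build_invocation_payload(records: List[Dict[str, Any]]) -> str:
    """Serialize row dictionaries into a `dataframe_split` JSON payload."""
    if not records:
        raise ValueError("at least one record is required")
    columns = list(records[0])
    for record in records:
        # Every row must carry exactly the same set of columns.
        if set(record) != set(columns):
            raise ValueError("all records must have the same keys")
    data = [[record[col] for col in columns] for record in records]
    return json.dumps({"dataframe_split": {"columns": columns, "data": data}})


payload = build_invocation_payload([
    {"tenure": 12, "monthly_charges": 70.5},
    {"monthly_charges": 20.0, "tenure": 48},
])
```

The payload would be posted with a `Content-Type: application/json` header to the port chosen with `-p`, for example `http://localhost:5001/invocations` for the command above.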

Quick Start

  1. Install MLflow:

    bash
    pip install mlflow
    
  2. Start tracking server:

    bash
    mlflow server --backend-store-uri sqlite:///mlflow.db \
                  --default-artifact-root s3://mlflow-artifacts \
                  --host 0.0.0.0
    
  3. Set tracking URI in code:

    python
    mlflow.set_tracking_uri("http://localhost:5000")
    
  4. Run experiment:

    python
    with mlflow.start_run():
        mlflow.log_param("param", value)
        mlflow.log_metric("metric", value)
        mlflow.sklearn.log_model(model, "model")
    
  5. View in UI: Open http://localhost:5000
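
MLflow also reads the tracking URI from the `MLFLOW_TRACKING_URI` environment variable, which avoids hard-coding the server address in step 3. A minimal sketch of resolving the URI with a local fallback follows. The helper name `resolve_tracking_uri` is hypothetical, and the default mirrors the Quick Start server address.

```python
import os
from typing import Mapping, Optional

DEFAULT_TRACKING_URI = "http://localhost:5000"


def resolve_tracking_uri(env: Optional[Mapping[str, str]] = None) -> str:
    """Return MLFLOW_TRACKING_URI if set and non-empty, else the local default."""
    env = os.environ if env is None else env
    uri = env.get("MLFLOW_TRACKING_URI", "").strip()
    return uri or DEFAULT_TRACKING_URI


# Explicit mappings are passed here so the result does not depend on the shell.
local_uri = resolve_tracking_uri({})
remote_uri = resolve_tracking_uri({"MLFLOW_TRACKING_URI": "http://mlflow-server:5000"})
```

The resolved value can be passed to `mlflow.set_tracking_uri`, which lets the same script run against a local server in development and a shared server in CI.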

Production Checklist

  • Tracking server with persistent backend
  • Artifact storage (S3/GCS/Azure Blob)
  • Authentication enabled
  • Model signature defined
  • Input examples logged
  • Conda/pip environment specified
  • Validation pipeline configured
  • Model approval workflow
  • Monitoring for model drift

Anti-patterns

  1. No Experiment Naming: Use meaningful experiment/run names
  2. Skipping Signatures: Always define model signatures
  3. Manual Promotion: Use a validation pipeline for stage transitions
  4. Missing Environment: Always specify dependencies

Integration Points

  • Storage: S3, GCS, Azure Blob, HDFS
  • Databases: PostgreSQL, MySQL for backend store
  • Orchestration: Airflow, Prefect, Dagster
  • Serving: SageMaker, Kubernetes, Azure ML
