The gap between a Jupyter notebook and a production ML system is vast. After building platforms that serve ML predictions to millions of users daily, I’ve learned that the ML model is often the easiest part. The hard problems are infrastructure, data pipelines, monitoring, and maintaining model quality over time. This post shares battle-tested patterns for production ML systems.

The Production ML Architecture

A production ML platform requires much more than just model serving:

Data Sources → Feature Pipeline → Feature Store → Model Serving → Application

Alongside this serving path sit three supporting components: Model Training (which consumes offline features from the feature store), the Model Registry (which versions trained models for serving), and Monitoring (which watches features, predictions, and model quality across the whole system).

Let’s break down each component.

Feature Pipelines: The Foundation

Features are the inputs to your ML models. Managing them at scale is critical.

Online vs Offline Features

import json
from datetime import datetime
from typing import Any, Dict, List

import numpy as np
from pandas import DataFrame

# `spark` (a SparkSession) and `redis` (a Redis client instance) are assumed
# to be configured at module level.

class FeaturePipeline:
    """
    Dual pipeline: batch for training, real-time for serving
    """

    def compute_offline_features(self, user_ids: List[str],
                                 start_date: datetime,
                                 end_date: datetime) -> DataFrame:
        """
        Batch computation for training data
        Can use complex aggregations, joins
        """
        # Use Spark or similar for large-scale processing
        user_activity = spark.sql(f"""
            SELECT
                user_id,
                COUNT(*) as activity_count,
                AVG(session_duration) as avg_session_duration,
                MAX(timestamp) as last_activity,
                COUNT(DISTINCT DATE(timestamp)) as active_days
            FROM user_activities
            WHERE timestamp BETWEEN '{start_date}' AND '{end_date}'
            GROUP BY user_id
        """)

        return user_activity.toPandas()

    def compute_online_features(self, user_id: str) -> Dict[str, Any]:
        """
        Real-time computation for inference
        Must be fast (<100ms), use caching
        """
        # Check cache first
        cached = redis.get(f"features:user:{user_id}")
        if cached:
            return json.loads(cached)

        # Compute from recent data
        recent_activities = self.get_recent_activities(user_id, hours=24)

        # Guard against users with no recent activity (avoids nan / max() on empty)
        if not recent_activities:
            features = {'activity_count': 0, 'avg_session_duration': 0.0,
                        'last_activity': None, 'active_days': 0}
        else:
            features = {
                'activity_count': len(recent_activities),
                'avg_session_duration': float(np.mean([a.duration for a in recent_activities])),
                'last_activity': max(a.timestamp for a in recent_activities).isoformat(),
                'active_days': len(set(a.timestamp.date() for a in recent_activities))
            }

        # Cache for 15 minutes
        redis.setex(
            f"features:user:{user_id}",
            900,
            json.dumps(features)
        )

        return features

Feature Consistency is Critical

The same feature must be computed identically for training and serving:

class FeatureDefinition:
    """
    Single source of truth for feature computation
    """

    @staticmethod
    def user_activity_score(activities: List[Activity]) -> float:
        """
        Used in both offline and online pipelines
        """
        if not activities:
            return 0.0

        # Consistent calculation
        weights = {'login': 1.0, 'purchase': 5.0, 'share': 3.0}
        score = sum(weights.get(a.type, 0.5) for a in activities)

        # Apply same transformations
        return np.log1p(score)

# Use in offline pipeline
def compute_training_features(user_history: DataFrame) -> DataFrame:
    user_history['activity_score'] = user_history['activities'].apply(
        FeatureDefinition.user_activity_score
    )
    return user_history

# Use in online pipeline
def compute_serving_features(user_id: str) -> Dict:
    activities = get_recent_activities(user_id)
    return {
        'activity_score': FeatureDefinition.user_activity_score(activities)
    }

Feature Store Implementation

A feature store provides versioned, reusable features:

from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List

from pyspark.sql.functions import col

@dataclass
class FeatureMetadata:
    name: str
    version: str
    feature_type: str
    description: str
    created_at: datetime
    tags: List[str]

class FeatureStore:
    """
    Central repository for ML features
    """

    def __init__(self, online_store, offline_store, metadata_db):
        self.online_store = online_store    # Redis/DynamoDB
        self.offline_store = offline_store  # S3/BigQuery
        self.metadata_db = metadata_db      # Catalog of registered feature metadata

    def register_feature(self, metadata: FeatureMetadata):
        """Register a new feature"""
        self.metadata_db.insert(metadata)

    def write_online_features(self, entity_id: str,
                             features: Dict[str, Any],
                             feature_set: str):
        """Write features for online serving"""
        key = f"{feature_set}:{entity_id}"
        self.online_store.hmset(key, features)
        self.online_store.expire(key, 3600)  # 1 hour TTL

    def get_online_features(self, entity_id: str,
                           feature_set: str,
                           feature_names: List[str]) -> Dict[str, Any]:
        """Retrieve features for inference"""
        key = f"{feature_set}:{entity_id}"
        features = self.online_store.hmget(key, feature_names)

        return dict(zip(feature_names, features))

    def write_offline_features(self, df: DataFrame,
                               feature_set: str,
                               timestamp: datetime):
        """Write features for training (batch)"""
        partition = timestamp.strftime("%Y-%m-%d")
        path = f"s3://features/{feature_set}/dt={partition}/"

        df.write.parquet(path, mode='overwrite')

    def get_offline_features(self, entity_ids: List[str],
                            feature_set: str,
                            start_date: datetime,
                            end_date: datetime) -> DataFrame:
        """Retrieve features for training"""
        path = f"s3://features/{feature_set}/"

        return spark.read.parquet(path) \
            .filter(col('dt').between(start_date.strftime("%Y-%m-%d"),
                                      end_date.strftime("%Y-%m-%d"))) \
            .filter(col('entity_id').isin(entity_ids))
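A quick usage sketch of the store; the redis_client, offline_store, and metadata_catalog objects here are hypothetical stand-ins for whatever backends you wire in:

# Illustrative wiring; backend objects are placeholders
store = FeatureStore(online_store=redis_client,
                     offline_store=offline_store,
                     metadata_db=metadata_catalog)

# Write features computed by the online pipeline
store.write_online_features(
    entity_id='user_123',
    features={'activity_score': 2.31, 'recency': 0.4, 'frequency': 12},
    feature_set='user_features'
)

# Read them back at inference time
features = store.get_online_features(
    entity_id='user_123',
    feature_set='user_features',
    feature_names=['activity_score', 'recency', 'frequency']
)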

Model Serving at Scale

Model Registry and Versioning

class ModelRegistry:
    """
    Centralized model versioning and metadata
    """

    def register_model(self, model_name: str, model_path: str,
                      metadata: Dict) -> str:
        """Register a new model version"""
        version = self.generate_version()

        model_info = {
            'name': model_name,
            'version': version,
            'path': model_path,
            'framework': metadata.get('framework'),
            'metrics': metadata.get('metrics'),
            'training_data': metadata.get('training_data'),
            'registered_at': datetime.now().isoformat(),
            'registered_by': metadata.get('user')
        }

        self.db.insert('models', model_info)

        return version

    def promote_model(self, model_name: str, version: str,
                     stage: str):
        """
        Promote model to staging/production
        """
        # Validate model before promotion
        self.run_validation_tests(model_name, version)

        # Update stage
        self.db.update('models',
            {'name': model_name, 'version': version},
            {'stage': stage, 'promoted_at': datetime.now()}
        )

        # Notify deployment system
        self.trigger_deployment(model_name, version, stage)
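A short usage sketch of the registry flow; the paths and metric values are illustrative placeholders:

# Illustrative registry flow; paths and metrics are placeholders
registry = ModelRegistry()

version = registry.register_model(
    model_name='churn_predictor',
    model_path='s3://models/churn_predictor/candidate/model.pkl',
    metadata={
        'framework': 'scikit-learn',
        'metrics': {'auc': 0.87},
        'training_data': 's3://features/user_features/',
        'user': 'ml-platform'
    }
)

# Promotion runs validation tests before updating the stage
registry.promote_model('churn_predictor', version, stage='staging')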

Model Serving Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import joblib
import time
from typing import Dict, List

app = FastAPI()

class PredictionRequest(BaseModel):
    entity_ids: List[str]
    features: List[Dict[str, float]]

class PredictionResponse(BaseModel):
    predictions: List[float]
    model_version: str
    latency_ms: float

class ModelServer:
    def __init__(self):
        self.models = {}
        self.current_model_key = None
        self.feature_store = FeatureStore()
        self.metrics = MetricsCollector()

    def load_model(self, model_name: str, version: str):
        """Load model into memory"""
        # In practice, download the artifact locally (or open it via s3fs)
        # before handing it to joblib.load
        model_path = f"s3://models/{model_name}/{version}/model.pkl"
        model = joblib.load(model_path)
        key = f"{model_name}:{version}"
        self.models[key] = model
        self.current_model_key = key

model_server = ModelServer()

# FastAPI can't bind `self` on a method registered from inside a class body,
# so the endpoint lives at module level and delegates to a shared ModelServer
@app.post("/predict")
async def predict(request: PredictionRequest) -> PredictionResponse:
    start_time = time.time()

    try:
        # Get the currently active model
        model_key = model_server.current_model_key
        model = model_server.models[model_key]
        model_version = model_key.split(':', 1)[1]

        # Feature retrieval
        features = []
        for entity_id in request.entity_ids:
            entity_features = model_server.feature_store.get_online_features(
                entity_id,
                'user_features',
                ['activity_score', 'recency', 'frequency']
            )
            features.append(entity_features)

        # Prepare feature matrix
        X = np.array([[f['activity_score'], f['recency'], f['frequency']]
                      for f in features])

        # Predict
        predictions = model.predict_proba(X)[:, 1].tolist()

        latency_ms = (time.time() - start_time) * 1000

        # Record metrics
        model_server.metrics.record_prediction(
            model_version=model_version,
            latency_ms=latency_ms,
            batch_size=len(request.entity_ids)
        )

        return PredictionResponse(
            predictions=predictions,
            model_version=model_version,
            latency_ms=latency_ms
        )

    except Exception as e:
        model_server.metrics.record_error(error_type=type(e).__name__)
        raise HTTPException(status_code=500, detail=str(e))

Batch Prediction Pipeline

For non-real-time predictions:

from datetime import datetime, timedelta

import pandas as pd

class BatchPredictionPipeline:
    """
    Generate predictions in batch (e.g., daily)
    """

    def run_batch_predictions(self, date: datetime):
        """
        Compute predictions for all users
        Store in database/cache for quick lookup
        """
        # Get all active users
        user_ids = self.get_active_users(date)

        # Get features in batch
        features_df = self.feature_store.get_offline_features(
            user_ids,
            'user_features',
            date - timedelta(days=30),
            date
        )

        # The feature store returns a Spark DataFrame; convert for local scoring
        features_df = features_df.toPandas()

        # Load model
        model = self.model_registry.get_production_model('churn_predictor')

        # Predict in batches
        batch_size = 10000
        predictions = []

        for i in range(0, len(features_df), batch_size):
            batch = features_df[i:i + batch_size]
            X = batch[['activity_score', 'recency', 'frequency']].values
            pred = model.predict_proba(X)[:, 1]
            predictions.extend(pred)

        # Store predictions, aligned with the rows the feature store returned
        results_df = pd.DataFrame({
            'user_id': features_df['entity_id'].values,
            'churn_probability': predictions,
            'prediction_date': date,
            'model_version': model.version
        })

        # Write to database
        self.write_predictions(results_df)

        # Cache top predictions for API
        high_risk = results_df[results_df['churn_probability'] > 0.7]
        for _, row in high_risk.iterrows():
            self.redis.setex(
                f"prediction:churn:{row['user_id']}",
                86400,  # 24 hours
                row['churn_probability']
            )

Model Monitoring and Observability

Prediction Quality Monitoring

class ModelMonitor:
    """
    Monitor model performance in production
    """

    def log_prediction(self, entity_id: str, features: Dict,
                      prediction: float, model_version: str):
        """
        Log every prediction for later analysis
        """
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'entity_id': entity_id,
            'features': features,
            'prediction': prediction,
            'model_version': model_version
        }

        # Write to data warehouse
        self.prediction_log.append(log_entry)

    def log_actual(self, entity_id: str, actual_value: float):
        """
        Log actual outcomes when available
        """
        # Match with predictions
        predictions = self.prediction_log.get_recent(entity_id)

        for pred in predictions:
            # Compute error
            error = abs(pred['prediction'] - actual_value)

            # Store for metric calculation
            self.metrics.record_prediction_error(
                model_version=pred['model_version'],
                error=error,
                actual=actual_value,
                predicted=pred['prediction']
            )

    def check_data_drift(self):
        """
        Detect feature distribution changes
        """
        # Get recent feature statistics
        recent_stats = self.compute_feature_stats(days=7)
        baseline_stats = self.get_baseline_stats()

        for feature, recent in recent_stats.items():
            baseline = baseline_stats[feature]

            # KL divergence for distribution shift
            kl_div = self.compute_kl_divergence(recent, baseline)

            if kl_div > 0.1:  # Threshold
                self.alert_data_drift(feature, kl_div)

    def check_model_performance(self):
        """
        Track model metrics over time
        """
        # Get recent predictions with actuals
        recent_data = self.get_recent_predictions_with_actuals(days=7)

        from sklearn.metrics import roc_auc_score, precision_score

        auc = roc_auc_score(recent_data['actual'],
                           recent_data['prediction'])

        precision = precision_score(
            recent_data['actual'] > 0.5,
            recent_data['prediction'] > 0.5
        )

        # Compare to baseline
        if auc < self.baseline_auc * 0.95:  # 5% degradation
            self.alert_model_degradation(
                metric='AUC',
                current=auc,
                baseline=self.baseline_auc
            )

        # Record metrics
        self.metrics_db.insert({
            'date': datetime.now().date(),
            'auc': auc,
            'precision': precision,
            'model_version': self.current_version
        })
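The drift check above relies on a compute_kl_divergence helper that isn't shown. Here is a minimal sketch, assuming the per-feature statistics are arrays of sampled raw values that can be binned into shared histograms:

import numpy as np

def compute_kl_divergence(recent_values, baseline_values, bins=20, eps=1e-9):
    """
    Approximate KL(recent || baseline) over shared histogram bins.
    Bin count and smoothing constant are illustrative, not tuned values.
    """
    # Shared bin edges so both distributions are directly comparable
    edges = np.histogram_bin_edges(
        np.concatenate([recent_values, baseline_values]), bins=bins
    )
    p, _ = np.histogram(recent_values, bins=edges)
    q, _ = np.histogram(baseline_values, bins=edges)

    # Normalize to probabilities, smoothing to avoid log(0) and division by zero
    p = (p + eps) / (p + eps).sum()
    q = (q + eps) / (q + eps).sum()

    return float(np.sum(p * np.log(p / q)))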

Feature Monitoring

import numpy as np
from scipy import stats

class FeatureMonitor:
    """
    Monitor feature quality and freshness
    """

    def check_feature_freshness(self):
        """
        Ensure features are being updated
        """
        for feature_set in self.feature_sets:
            last_update = self.get_last_update_time(feature_set)
            age_hours = (datetime.now() - last_update).total_seconds() / 3600

            if age_hours > 2:  # Threshold
                self.alert_stale_features(feature_set, age_hours)

    def check_feature_quality(self):
        """
        Monitor for data quality issues
        """
        for feature in self.features:
            recent_values = self.get_recent_values(feature)

            # Check for nulls
            null_rate = sum(v is None for v in recent_values) / len(recent_values)
            if null_rate > 0.05:  # 5% threshold
                self.alert_high_null_rate(feature, null_rate)

            # Check for outliers
            if self.is_numeric(feature):
                z_scores = np.abs(stats.zscore(recent_values))
                outlier_rate = sum(z > 3 for z in z_scores) / len(z_scores)

                if outlier_rate > 0.01:
                    self.alert_high_outlier_rate(feature, outlier_rate)

A/B Testing ML Models

import hashlib

class ModelABTest:
    """
    Safely test new models against current production
    """

    def __init__(self):
        self.control_model = None  # Current production
        self.treatment_model = None  # New candidate
        self.traffic_split = 0.95  # 95% control, 5% treatment

    def predict(self, entity_id: str, features: Dict) -> float:
        """
        Route traffic between models
        """
        # Deterministic assignment based on entity_id
        assignment = self.get_assignment(entity_id)

        if assignment == 'treatment':
            model = self.treatment_model
            variant = 'treatment'
        else:
            model = self.control_model
            variant = 'control'

        # Make prediction
        prediction = model.predict(features)

        # Log for analysis
        self.log_ab_prediction(
            entity_id=entity_id,
            variant=variant,
            prediction=prediction,
            features=features
        )

        return prediction

    def get_assignment(self, entity_id: str) -> str:
        """
        Consistent assignment based on hash
        """
        hash_val = int(hashlib.md5(entity_id.encode()).hexdigest(), 16)
        bucket = (hash_val % 100) / 100.0

        return 'treatment' if bucket < (1 - self.traffic_split) else 'control'

    def analyze_results(self):
        """
        Compare model performance between variants
        """
        control_metrics = self.compute_metrics('control')
        treatment_metrics = self.compute_metrics('treatment')

        # Statistical significance test
        from scipy import stats

        t_stat, p_value = stats.ttest_ind(
            control_metrics['errors'],
            treatment_metrics['errors']
        )

        if p_value < 0.05 and treatment_metrics['mean_error'] < control_metrics['mean_error']:
            print("Treatment model is significantly better!")
            self.promote_treatment_model()
        else:
            print("No significant improvement detected")

Training Pipeline Automation

class AutomatedTrainingPipeline:
    """
    Automated model retraining
    """

    def should_retrain(self) -> bool:
        """
        Determine if retraining is needed
        """
        # Check time since last training
        days_since_training = (datetime.now() - self.last_training_date).days

        if days_since_training > 7:
            return True

        # Check performance degradation
        current_performance = self.get_current_performance()
        if current_performance < self.baseline_performance * 0.95:
            return True

        # Check data drift
        if self.monitor.detect_data_drift():
            return True

        return False

    def retrain_model(self):
        """
        Automated retraining workflow
        """
        # Get fresh training data
        end_date = datetime.now()
        start_date = end_date - timedelta(days=90)

        training_data = self.feature_store.get_offline_features(
            entity_ids=self.get_training_entity_ids(),
            feature_set='user_features',
            start_date=start_date,
            end_date=end_date
        ).toPandas()  # the feature store returns a Spark DataFrame; convert for scikit-learn

        # Train model
        from sklearn.ensemble import GradientBoostingClassifier

        X = training_data[self.feature_columns].values
        y = training_data['label'].values

        model = GradientBoostingClassifier(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=5
        )

        model.fit(X, y)

        # Validate
        val_score = self.validate_model(model)

        if val_score > self.minimum_acceptable_score:
            # Persist the artifact first: the ModelRegistry above expects a
            # storage path, not the in-memory model object
            model_path = self.save_model_artifact(model)  # hypothetical helper

            # Register new model version
            version = self.model_registry.register_model(
                'churn_predictor',
                model_path,
                metadata={
                    'validation_score': val_score,
                    'training_date': datetime.now().isoformat(),
                    'training_samples': len(X)
                }
            )

            # Deploy to staging for A/B test
            self.deploy_to_staging(version)
        else:
            self.alert_training_failed(val_score)
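In practice these two methods are driven by a scheduler (cron, Airflow, or similar); a minimal sketch of the daily entry point, with the orchestration wiring assumed:

# Hypothetical entry point invoked by your scheduler of choice
def daily_retraining_job():
    pipeline = AutomatedTrainingPipeline()

    if pipeline.should_retrain():
        pipeline.retrain_model()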

Key Takeaways

  1. Feature consistency is critical: Use the same code for training and serving
  2. Monitor everything: Models degrade over time without monitoring
  3. Automate retraining: Don’t wait for manual interventions
  4. Use A/B testing: Never deploy untested models to 100% traffic
  5. Feature stores reduce duplication: Centralize feature definitions
  6. Plan for failure: Models will fail, so have fallbacks ready (see the sketch after this list)
  7. Version everything: Models, features, data
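
A minimal sketch of the fallback pattern from point 6, reusing names from earlier sections; feature_store, model, and redis_client are assumed to be configured as shown above:

def predict_with_fallback(entity_id: str) -> float:
    """Degrade gracefully: live model → cached batch prediction → conservative default."""
    try:
        features = feature_store.get_online_features(
            entity_id, 'user_features',
            ['activity_score', 'recency', 'frequency']
        )
        X = [[features['activity_score'], features['recency'], features['frequency']]]
        return float(model.predict_proba(X)[0][1])
    except Exception:
        # Fall back to the most recent batch prediction, if one is cached
        cached = redis_client.get(f"prediction:churn:{entity_id}")
        if cached is not None:
            return float(cached)
        # Last resort: a conservative default score
        return 0.5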

Production ML is 10% modeling and 90% engineering. Build robust infrastructure first, then iterate on models.