The gap between a Jupyter notebook and a production ML system is vast. After building platforms that serve ML predictions to millions of users daily, I’ve learned that the ML model is often the easiest part. The hard problems are infrastructure, data pipelines, monitoring, and maintaining model quality over time. This post shares battle-tested patterns for production ML systems.
The Production ML Architecture
A production ML platform requires much more than just model serving:
Data Sources → Feature Pipeline → Feature Store → Model Serving → Application
                                        ↓
                                 Model Training
                                        ↓
                                 Model Registry
                                        ↓
                                   Monitoring
Let’s break down each component.
Feature Pipelines: The Foundation
Features are the inputs to your ML models. Managing them at scale is critical.
Online vs Offline Features
import json
from datetime import datetime
from typing import Any, Dict, List

import numpy as np
from pandas import DataFrame

# `spark` and `redis` are assumed to be pre-initialized SparkSession and Redis clients
class FeaturePipeline:
"""
Dual pipeline: batch for training, real-time for serving
"""
def compute_offline_features(self, user_ids: List[str],
start_date: datetime,
end_date: datetime) -> DataFrame:
"""
Batch computation for training data
Can use complex aggregations, joins
"""
# Use Spark or similar for large-scale processing
user_activity = spark.sql(f"""
SELECT
user_id,
COUNT(*) as activity_count,
AVG(session_duration) as avg_session_duration,
MAX(timestamp) as last_activity,
COUNT(DISTINCT DATE(timestamp)) as active_days
FROM user_activities
WHERE timestamp BETWEEN '{start_date}' AND '{end_date}'
GROUP BY user_id
""")
return user_activity.toPandas()
def compute_online_features(self, user_id: str) -> Dict[str, Any]:
"""
Real-time computation for inference
Must be fast (<100ms), use caching
"""
# Check cache first
cached = redis.get(f"features:user:{user_id}")
if cached:
return json.loads(cached)
        # Compute from recent data (guard against users with no recent activity)
        recent_activities = self.get_recent_activities(user_id, hours=24)
        if not recent_activities:
            features = {
                'activity_count': 0,
                'avg_session_duration': 0.0,
                'last_activity': None,
                'active_days': 0
            }
        else:
            features = {
                'activity_count': len(recent_activities),
                'avg_session_duration': float(np.mean([a.duration for a in recent_activities])),
                'last_activity': max(a.timestamp for a in recent_activities).isoformat(),
                'active_days': len({a.timestamp.date() for a in recent_activities})
            }
# Cache for 15 minutes
redis.setex(
f"features:user:{user_id}",
900,
json.dumps(features)
)
return features
Feature Consistency is Critical
The same feature must be computed identically for training and serving:
class FeatureDefinition:
"""
Single source of truth for feature computation
"""
@staticmethod
def user_activity_score(activities: List[Activity]) -> float:
"""
Used in both offline and online pipelines
"""
if not activities:
return 0.0
# Consistent calculation
weights = {'login': 1.0, 'purchase': 5.0, 'share': 3.0}
score = sum(weights.get(a.type, 0.5) for a in activities)
# Apply same transformations
return np.log1p(score)
# Use in offline pipeline
def compute_training_features(user_history: DataFrame) -> DataFrame:
user_history['activity_score'] = user_history['activities'].apply(
FeatureDefinition.user_activity_score
)
return user_history
# Use in online pipeline
def compute_serving_features(user_id: str) -> Dict:
activities = get_recent_activities(user_id)
return {
'activity_score': FeatureDefinition.user_activity_score(activities)
}
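One way to guard this contract is a parity test that pushes the same activities through both paths. A minimal sketch, assuming Activity objects can be constructed with just a type field (the sample values are illustrative):

import numpy as np
import pandas as pd

def test_offline_online_parity():
    # Hypothetical activities for a single user (Activity construction is assumed)
    activities = [Activity(type='login'), Activity(type='purchase'), Activity(type='share')]

    # Offline path: score computed through the training DataFrame pipeline
    history = pd.DataFrame({'user_id': ['u1'], 'activities': [activities]})
    offline_score = compute_training_features(history)['activity_score'].iloc[0]

    # Online path: score computed directly from the shared definition
    online_score = FeatureDefinition.user_activity_score(activities)

    # Both paths must produce exactly the same value
    assert np.isclose(offline_score, online_score)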
Feature Store Implementation
A feature store provides versioned, reusable features:
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

from pyspark.sql import DataFrame
from pyspark.sql.functions import col
@dataclass
class FeatureMetadata:
name: str
version: str
feature_type: str
description: str
created_at: datetime
tags: List[str]
class FeatureStore:
"""
Central repository for ML features
"""
    def __init__(self, online_store, offline_store, metadata_db=None):
        self.online_store = online_store    # Redis/DynamoDB
        self.offline_store = offline_store  # S3/BigQuery
        self.metadata_db = metadata_db      # metadata catalog (e.g. a SQL table)

    def register_feature(self, metadata: FeatureMetadata):
        """Register a new feature"""
        self.metadata_db.insert(metadata)
    def write_online_features(self, entity_id: str,
                              features: Dict[str, Any],
                              feature_set: str):
        """Write features for online serving"""
        key = f"{feature_set}:{entity_id}"
        # hmset is deprecated in redis-py; hset with mapping= is the current API
        self.online_store.hset(key, mapping=features)
        self.online_store.expire(key, 3600)  # 1 hour TTL

    def get_online_features(self, entity_id: str,
                            feature_set: str,
                            feature_names: List[str]) -> Dict[str, Any]:
        """Retrieve features for inference"""
        key = f"{feature_set}:{entity_id}"
        values = self.online_store.hmget(key, feature_names)
        # Redis returns strings/bytes; cast numeric features back to floats
        return {name: float(value) if value is not None else None
                for name, value in zip(feature_names, values)}
def write_offline_features(self, df: DataFrame,
feature_set: str,
timestamp: datetime):
"""Write features for training (batch)"""
partition = timestamp.strftime("%Y-%m-%d")
path = f"s3://features/{feature_set}/dt={partition}/"
df.write.parquet(path, mode='overwrite')
def get_offline_features(self, entity_ids: List[str],
feature_set: str,
start_date: datetime,
end_date: datetime) -> DataFrame:
"""Retrieve features for training"""
        path = f"s3://features/{feature_set}/"
        # The partition column `dt` is a string, so compare against formatted dates
        return spark.read.parquet(path) \
            .filter(col('dt').between(start_date.strftime("%Y-%m-%d"),
                                      end_date.strftime("%Y-%m-%d"))) \
            .filter(col('entity_id').isin(entity_ids))
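To make the interface concrete, here is a minimal usage sketch, assuming a redis-py client as the online store (the IDs and values are illustrative):

import redis

feature_store = FeatureStore(
    online_store=redis.Redis(host='localhost', port=6379),
    offline_store=None  # e.g. a Spark/S3 wrapper in a real deployment
)

# Write fresh features after the online pipeline runs
feature_store.write_online_features(
    entity_id='user_123',
    features={'activity_score': 2.3, 'recency': 1.0, 'frequency': 14.0},
    feature_set='user_features'
)

# Read them back at inference time
online = feature_store.get_online_features(
    entity_id='user_123',
    feature_set='user_features',
    feature_names=['activity_score', 'recency', 'frequency']
)
print(online)  # {'activity_score': 2.3, 'recency': 1.0, 'frequency': 14.0}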
Model Serving at Scale
Model Registry and Versioning
class ModelRegistry:
"""
Centralized model versioning and metadata
"""
def register_model(self, model_name: str, model_path: str,
metadata: Dict) -> str:
"""Register a new model version"""
version = self.generate_version()
model_info = {
'name': model_name,
'version': version,
'path': model_path,
'framework': metadata.get('framework'),
'metrics': metadata.get('metrics'),
'training_data': metadata.get('training_data'),
'registered_at': datetime.now().isoformat(),
'registered_by': metadata.get('user')
}
self.db.insert('models', model_info)
return version
def promote_model(self, model_name: str, version: str,
stage: str):
"""
Promote model to staging/production
"""
# Validate model before promotion
self.run_validation_tests(model_name, version)
# Update stage
self.db.update('models',
{'name': model_name, 'version': version},
{'stage': stage, 'promoted_at': datetime.now()}
)
# Notify deployment system
self.trigger_deployment(model_name, version, stage)
Model Serving Service
import time
from typing import Dict, List

import fsspec
import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
app = FastAPI()
class PredictionRequest(BaseModel):
entity_ids: List[str]
features: List[Dict[str, float]]
class PredictionResponse(BaseModel):
predictions: List[float]
model_version: str
latency_ms: float
class ModelServer:
    def __init__(self):
        self.models = {}
        self.current_key = None  # "name:version" of the model serving traffic
        self.feature_store = FeatureStore()
        self.metrics = MetricsCollector()

    def load_model(self, model_name: str, version: str):
        """Load model into memory and mark it as current"""
        model_path = f"s3://models/{model_name}/{version}/model.pkl"
        # joblib cannot read s3:// URIs directly, so open the object via fsspec
        with fsspec.open(model_path, "rb") as f:
            self.models[f"{model_name}:{version}"] = joblib.load(f)
        self.current_key = f"{model_name}:{version}"

model_server = ModelServer()

@app.post("/predict")
async def predict(request: PredictionRequest) -> PredictionResponse:
    start_time = time.time()
    try:
        # Get the current production model
        model = model_server.models[model_server.current_key]
        model_version = model_server.current_key.split(":", 1)[1]

        # Feature retrieval
        features = []
        for entity_id in request.entity_ids:
            entity_features = model_server.feature_store.get_online_features(
                entity_id,
                'user_features',
                ['activity_score', 'recency', 'frequency']
            )
            features.append(entity_features)

        # Prepare feature matrix
        X = np.array([[f['activity_score'], f['recency'], f['frequency']]
                      for f in features])

        # Predict
        predictions = model.predict_proba(X)[:, 1].tolist()
        latency_ms = (time.time() - start_time) * 1000

        # Record metrics
        model_server.metrics.record_prediction(
            model_version=model_version,
            latency_ms=latency_ms,
            batch_size=len(request.entity_ids)
        )
        return PredictionResponse(
            predictions=predictions,
            model_version=model_version,
            latency_ms=latency_ms
        )
    except Exception as e:
        model_server.metrics.record_error(error_type=type(e).__name__)
        raise HTTPException(status_code=500, detail=str(e))
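For completeness, a client call might look like the following (assuming the service runs locally on port 8000; the payload fields mirror PredictionRequest):

import requests

response = requests.post(
    "http://localhost:8000/predict",
    json={
        "entity_ids": ["user_123", "user_456"],
        "features": []  # features are fetched server-side from the feature store
    },
    timeout=1.0  # keep a tight client-side budget for online inference
)
response.raise_for_status()
body = response.json()
print(body["predictions"], body["model_version"], body["latency_ms"])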
Batch Prediction Pipeline
For non-real-time predictions:
class BatchPredictionPipeline:
"""
Generate predictions in batch (e.g., daily)
"""
def run_batch_predictions(self, date: datetime):
"""
Compute predictions for all users
Store in database/cache for quick lookup
"""
# Get all active users
user_ids = self.get_active_users(date)
        # Get features in batch (Spark DataFrame converted to pandas for scoring)
        features_df = self.feature_store.get_offline_features(
            user_ids,
            'user_features',
            date - timedelta(days=30),
            date
        ).toPandas()
# Load model
model = self.model_registry.get_production_model('churn_predictor')
# Predict in batches
batch_size = 10000
predictions = []
for i in range(0, len(features_df), batch_size):
batch = features_df[i:i + batch_size]
X = batch[['activity_score', 'recency', 'frequency']].values
pred = model.predict_proba(X)[:, 1]
predictions.extend(pred)
        # Store predictions (take entity IDs from the feature frame so rows stay aligned)
        results_df = pd.DataFrame({
            'user_id': features_df['entity_id'].values,
            'churn_probability': predictions,
            'prediction_date': date,
            'model_version': model.version
        })
# Write to database
self.write_predictions(results_df)
# Cache top predictions for API
high_risk = results_df[results_df['churn_probability'] > 0.7]
for _, row in high_risk.iterrows():
self.redis.setex(
f"prediction:churn:{row['user_id']}",
86400, # 24 hours
row['churn_probability']
)
Model Monitoring and Observability
Prediction Quality Monitoring
class ModelMonitor:
"""
Monitor model performance in production
"""
def log_prediction(self, entity_id: str, features: Dict,
prediction: float, model_version: str):
"""
Log every prediction for later analysis
"""
log_entry = {
'timestamp': datetime.now().isoformat(),
'entity_id': entity_id,
'features': features,
'prediction': prediction,
'model_version': model_version
}
# Write to data warehouse
self.prediction_log.append(log_entry)
def log_actual(self, entity_id: str, actual_value: float):
"""
Log actual outcomes when available
"""
# Match with predictions
predictions = self.prediction_log.get_recent(entity_id)
for pred in predictions:
# Compute error
error = abs(pred['prediction'] - actual_value)
# Store for metric calculation
self.metrics.record_prediction_error(
model_version=pred['model_version'],
error=error,
actual=actual_value,
predicted=pred['prediction']
)
def check_data_drift(self):
"""
Detect feature distribution changes
"""
# Get recent feature statistics
recent_stats = self.compute_feature_stats(days=7)
baseline_stats = self.get_baseline_stats()
for feature, recent in recent_stats.items():
baseline = baseline_stats[feature]
# KL divergence for distribution shift
kl_div = self.compute_kl_divergence(recent, baseline)
if kl_div > 0.1: # Threshold
self.alert_data_drift(feature, kl_div)
def check_model_performance(self):
"""
Track model metrics over time
"""
# Get recent predictions with actuals
recent_data = self.get_recent_predictions_with_actuals(days=7)
from sklearn.metrics import roc_auc_score, precision_score
auc = roc_auc_score(recent_data['actual'],
recent_data['prediction'])
precision = precision_score(
recent_data['actual'] > 0.5,
recent_data['prediction'] > 0.5
)
# Compare to baseline
if auc < self.baseline_auc * 0.95: # 5% degradation
self.alert_model_degradation(
metric='AUC',
current=auc,
baseline=self.baseline_auc
)
# Record metrics
self.metrics_db.insert({
'date': datetime.now().date(),
'auc': auc,
'precision': precision,
'model_version': self.current_version
})
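The drift check above leaves compute_kl_divergence abstract. A minimal sketch of one possible implementation, assuming both arguments are arrays of raw feature values rather than precomputed statistics:

import numpy as np
from scipy.stats import entropy

def compute_kl_divergence(recent_values, baseline_values, bins=20, eps=1e-9):
    """Approximate KL(recent || baseline) for a numeric feature from two samples."""
    # Bin both samples over a shared range so the histograms are comparable
    edges = np.histogram_bin_edges(np.concatenate([recent_values, baseline_values]), bins=bins)
    recent_hist, _ = np.histogram(recent_values, bins=edges)
    baseline_hist, _ = np.histogram(baseline_values, bins=edges)

    # Normalize to probability distributions, with a small epsilon to avoid zero bins
    p = recent_hist / recent_hist.sum() + eps
    q = baseline_hist / baseline_hist.sum() + eps
    return float(entropy(p, q))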
Feature Monitoring
class FeatureMonitor:
"""
Monitor feature quality and freshness
"""
def check_feature_freshness(self):
"""
Ensure features are being updated
"""
for feature_set in self.feature_sets:
last_update = self.get_last_update_time(feature_set)
            age_hours = (datetime.now() - last_update).total_seconds() / 3600
if age_hours > 2: # Threshold
self.alert_stale_features(feature_set, age_hours)
def check_feature_quality(self):
"""
Monitor for data quality issues
"""
        for feature in self.features:
            recent_values = self.get_recent_values(feature)

            # Check for nulls
            null_rate = sum(v is None for v in recent_values) / len(recent_values)
            if null_rate > 0.05:  # 5% threshold
                self.alert_high_null_rate(feature, null_rate)

            # Check for outliers (drop nulls first so the z-scores are well-defined)
            if self.is_numeric(feature):
                non_null = [v for v in recent_values if v is not None]
                z_scores = np.abs(stats.zscore(non_null))
                outlier_rate = sum(z > 3 for z in z_scores) / len(z_scores)
                if outlier_rate > 0.01:
                    self.alert_high_outlier_rate(feature, outlier_rate)
A/B Testing ML Models
import hashlib

class ModelABTest:
"""
Safely test new models against current production
"""
def __init__(self):
self.control_model = None # Current production
self.treatment_model = None # New candidate
self.traffic_split = 0.95 # 95% control, 5% treatment
def predict(self, entity_id: str, features: Dict) -> float:
"""
Route traffic between models
"""
# Deterministic assignment based on entity_id
assignment = self.get_assignment(entity_id)
if assignment == 'treatment':
model = self.treatment_model
variant = 'treatment'
else:
model = self.control_model
variant = 'control'
# Make prediction
prediction = model.predict(features)
# Log for analysis
self.log_ab_prediction(
entity_id=entity_id,
variant=variant,
prediction=prediction,
features=features
)
return prediction
def get_assignment(self, entity_id: str) -> str:
"""
Consistent assignment based on hash
"""
hash_val = int(hashlib.md5(entity_id.encode()).hexdigest(), 16)
bucket = (hash_val % 100) / 100.0
return 'treatment' if bucket < (1 - self.traffic_split) else 'control'
def analyze_results(self):
"""
Compare model performance between variants
"""
control_metrics = self.compute_metrics('control')
treatment_metrics = self.compute_metrics('treatment')
# Statistical significance test
from scipy import stats
t_stat, p_value = stats.ttest_ind(
control_metrics['errors'],
treatment_metrics['errors']
)
if p_value < 0.05 and treatment_metrics['mean_error'] < control_metrics['mean_error']:
print("Treatment model is significantly better!")
self.promote_treatment_model()
else:
print("No significant improvement detected")
Training Pipeline Automation
class AutomatedTrainingPipeline:
"""
Automated model retraining
"""
def should_retrain(self) -> bool:
"""
Determine if retraining is needed
"""
# Check time since last training
days_since_training = (datetime.now() - self.last_training_date).days
if days_since_training > 7:
return True
# Check performance degradation
current_performance = self.get_current_performance()
if current_performance < self.baseline_performance * 0.95:
return True
# Check data drift
if self.monitor.detect_data_drift():
return True
return False
def retrain_model(self):
"""
Automated retraining workflow
"""
# Get fresh training data
end_date = datetime.now()
start_date = end_date - timedelta(days=90)
training_data = self.feature_store.get_offline_features(
entity_ids=self.get_training_entity_ids(),
feature_set='user_features',
start_date=start_date,
end_date=end_date
)
# Train model
from sklearn.ensemble import GradientBoostingClassifier
X = training_data[self.feature_columns].values
y = training_data['label'].values
model = GradientBoostingClassifier(
n_estimators=100,
learning_rate=0.1,
max_depth=5
)
model.fit(X, y)
# Validate
val_score = self.validate_model(model)
        if val_score > self.minimum_acceptable_score:
            # Persist the artifact first: register_model expects a path, not a model object
            model_path = self.save_model_artifact(model)  # helper not shown; uploads the pickled model
            version = self.model_registry.register_model(
                'churn_predictor',
                model_path,
                metadata={
                    'validation_score': val_score,
                    'training_date': datetime.now().isoformat(),
                    'training_samples': len(X)
                }
            )
            # Deploy to staging for A/B test
            self.deploy_to_staging(version)
else:
self.alert_training_failed(val_score)
Key Takeaways
- Feature consistency is critical: Use the same code for training and serving
- Monitor everything: Models degrade over time without monitoring
- Automate retraining: Don’t wait for manual interventions
- Use A/B testing: Never deploy untested models to 100% traffic
- Feature stores reduce duplication: Centralize feature definitions
- Plan for failure: Models will fail, so have fallbacks ready (see the sketch after this list)
- Version everything: Models, features, data
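On the "plan for failure" point, a serving path can degrade gracefully: try the live model, fall back to the latest cached batch prediction, then to a conservative default. A minimal sketch; the function name and wiring are assumptions, but the cache key matches the batch pipeline above:

def predict_with_fallback(model, feature_store, redis_client,
                          user_id: str, default_score: float = 0.5) -> float:
    """Hypothetical helper: return a live prediction if possible, otherwise degrade gracefully."""
    try:
        # Primary path: real-time features + live model
        features = feature_store.get_online_features(
            user_id, 'user_features', ['activity_score', 'recency', 'frequency']
        )
        X = [[features['activity_score'], features['recency'], features['frequency']]]
        return float(model.predict_proba(X)[0][1])
    except Exception:
        # Fallback 1: most recent batch prediction cached by the daily pipeline
        cached = redis_client.get(f"prediction:churn:{user_id}")
        if cached is not None:
            return float(cached)
        # Fallback 2: a conservative global default
        return default_score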
Production ML is 10% modeling and 90% engineering. Build robust infrastructure first, then iterate on models.