Security analytics has evolved from rule-based detection to sophisticated ML systems that can identify novel threats in real time. This post explores how to architect ML-powered security analytics at scale.
Streaming Analytics Architecture
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
class SecurityEvent:
    """A single security event flowing through the detection pipeline.

    The constructor stores its arguments unchanged; no validation or
    conversion is performed.

    Args:
        timestamp: When the event occurred.
        source_ip: Address the event originated from.
        event_type: Category label of the event.
        features: Feature payload attached to the event.
    """

    def __init__(self, timestamp, source_ip, event_type, features):
        # NOTE(review): FeatureExtractor reads ``timestamp.hour`` and
        # ``timestamp.weekday()``, so a ``datetime`` is presumably expected.
        self.timestamp = timestamp
        self.source_ip = source_ip
        self.event_type = event_type
        self.features = features

    def __repr__(self):
        """Return a constructor-style representation for logs and debugging."""
        return (
            f"{type(self).__name__}(timestamp={self.timestamp!r}, "
            f"source_ip={self.source_ip!r}, event_type={self.event_type!r}, "
            f"features={self.features!r})"
        )
class ThreatDetectionPipeline:
    """Streaming pipeline that scores security events and alerts on threats.

    Args:
        model: Object handed to ``MLInference``, which calls its ``predict``
            method.
        risk_threshold: Results whose ``'risk_score'`` is strictly greater
            than this value are forwarded to the alert step.
        window_seconds: Size, in seconds, of the fixed windows applied before
            aggregation.
    """

    def __init__(self, model, risk_threshold=0.8, window_seconds=60):
        self.model = model
        self.risk_threshold = risk_threshold
        self.window_seconds = window_seconds

    @staticmethod
    def _is_threat(result, risk_threshold):
        """Return True when a scored result exceeds ``risk_threshold``."""
        # MLInference yields plain dicts, so the score has to be read by key;
        # attribute access (``result.risk_score``) raises AttributeError.
        return result['risk_score'] > risk_threshold

    def build_pipeline(self):
        """Assemble the detection pipeline and run it.

        The pipeline is created in a ``with`` block, so it is executed when
        the block exits.

        NOTE(review): ``parse_event``, ``BehavioralAggregator`` and
        ``AlertSender`` are referenced here but are not defined in the
        visible source. ``CombinePerKey`` emits ``(key, combined_value)``
        pairs; confirm the aggregator's output matches the
        ``(source_ip, (event, features))`` shape ``MLInference`` unpacks.
        """
        options = PipelineOptions()
        with beam.Pipeline(options=options) as pipeline:
            (
                pipeline
                # NOTE(review): Pub/Sub topics are normally addressed as
                # 'projects/<project>/topics/<topic>' -- confirm this value.
                | 'Read Events' >> beam.io.ReadFromPubSub(topic='security-events')
                | 'Parse Events' >> beam.Map(self.parse_event)
                | 'Extract Features' >> beam.ParDo(FeatureExtractor())
                | 'Windowing' >> beam.WindowInto(
                    beam.window.FixedWindows(self.window_seconds))
                | 'Aggregate' >> beam.CombinePerKey(BehavioralAggregator())
                | 'ML Inference' >> beam.ParDo(MLInference(self.model))
                # Extra positional arguments to Filter are passed on to the
                # predicate after the element.
                | 'Filter Threats' >> beam.Filter(
                    self._is_threat, self.risk_threshold)
                | 'Alert' >> beam.ParDo(AlertSender())
            )
class FeatureExtractor(beam.DoFn):
    """DoFn that derives model features for an event, keyed by source IP."""

    def process(self, event):
        """Yield ``(source_ip, (event, features))`` for a single event.

        NOTE(review): ``encode_event_type`` and ``get_ip_reputation`` are
        called on ``self`` but are not defined on this class in the visible
        source.
        """
        hour = event.timestamp.hour
        weekday = event.timestamp.weekday()
        encoded_type = self.encode_event_type(event.event_type)
        reputation = self.get_ip_reputation(event.source_ip)

        features = dict(
            hour_of_day=hour,
            day_of_week=weekday,
            event_type_encoded=encoded_type,
            source_ip_reputation=reputation,
        )
        yield event.source_ip, (event, features)
class MLInference(beam.DoFn):
    """DoFn that attaches a model risk score to each keyed event."""

    def __init__(self, model):
        # The model only needs a ``predict`` method accepting a list of rows.
        self.model = model

    def process(self, element):
        """Score one ``(source_ip, (event, features))`` element.

        The feature dict's values, in insertion order, form the single row
        passed to ``model.predict``; the first prediction is the risk score.
        """
        source_ip, payload = element
        event, features = payload

        feature_row = list(features.values())
        predictions = self.model.predict([feature_row])

        result = {
            'source_ip': source_ip,
            'event': event,
            'risk_score': predictions[0],
            'features': features,
        }
        yield result
Anomaly Detection Models
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler
import numpy as np
class AnomalyDetector:
    """Isolation-forest anomaly detector over numeric event features.

    Features are standardised with a ``StandardScaler`` before being passed
    to an ``IsolationForest``.
    """

    # Number of columns produced by ``extract_features``.
    _N_FEATURES = 5

    def __init__(self):
        self.scaler = StandardScaler()
        self.model = IsolationForest(
            contamination=0.01,  # Expect 1% anomalies
            n_estimators=100,
            max_samples=256,
            random_state=42,
        )
        self.is_fitted = False

    def fit(self, normal_traffic):
        """Train on known-good traffic.

        Args:
            normal_traffic: Sequence of events exposing the attributes read
                by ``extract_features``.
        """
        features = self.extract_features(normal_traffic)
        scaled = self.scaler.fit_transform(features)
        self.model.fit(scaled)
        self.is_fitted = True

    def predict(self, events):
        """Detect anomalies in new events.

        Args:
            events: Indexable sequence of events.

        Returns:
            A list with one dict per event the model labels ``-1``
            (anomalous), holding the ``'event'``, its ``'anomaly_score'``
            (negated decision function, so higher means more anomalous) and
            its raw ``'features'`` row. Empty input yields an empty list.

        Raises:
            sklearn.exceptions.NotFittedError: If called before ``fit``.
        """
        if not self.is_fitted:
            # Same exception type scikit-learn itself would raise, but
            # raised up front with a message naming this class.
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                "AnomalyDetector.predict() called before fit()")

        # The scaler rejects zero-sample input; nothing to score anyway.
        if len(events) == 0:
            return []

        features = self.extract_features(events)
        scaled = self.scaler.transform(features)
        scores = self.model.decision_function(scaled)
        predictions = self.model.predict(scaled)

        # -1 = anomaly, 1 = normal
        return [
            {
                'event': events[i],
                'anomaly_score': -scores[i],  # Higher = more anomalous
                'features': features[i],
            }
            for i, pred in enumerate(predictions)
            if pred == -1
        ]

    def extract_features(self, events):
        """Extract numerical features from events.

        Returns:
            A 2-D array with one row per event and ``_N_FEATURES`` columns;
            empty input gives shape ``(0, _N_FEATURES)`` rather than ``(0,)``.

        NOTE(review): ``time_since_last_event`` is called on ``self`` but is
        not defined on this class in the visible source.
        """
        rows = [
            [
                event.packet_size,
                event.packets_per_second,
                event.unique_destinations,
                event.failed_attempts,
                self.time_since_last_event(event),
            ]
            for event in events
        ]
        return np.array(rows).reshape(-1, self._N_FEATURES)
Behavioral Profiling
import time
from collections import Counter, defaultdict
from datetime import datetime, timedelta
class BehavioralProfile:
    """Per-entity rolling behavioural baselines.

    Args:
        window_hours: Length of the rolling window, in hours, over which
            events are retained and baselines computed.
    """

    def __init__(self, window_hours=24):
        self.window = timedelta(hours=window_hours)
        self.profiles = defaultdict(lambda: {
            'events': [],
            'baseline': None,
        })

    def update(self, entity_id, event):
        """Update behavioral profile.

        Appends ``event``, drops events whose ``timestamp`` is not newer
        than ``now - window``, and recomputes the baseline once more than
        100 events remain.
        """
        profile = self.profiles[entity_id]

        # Add event
        profile['events'].append(event)

        # Remove old events outside window.
        # NOTE(review): uses naive ``datetime.now()``; event timestamps are
        # presumably naive local datetimes too -- confirm.
        cutoff = datetime.now() - self.window
        profile['events'] = [
            e for e in profile['events']
            if e.timestamp > cutoff
        ]

        # Update baseline if enough data
        if len(profile['events']) > 100:
            profile['baseline'] = self._compute_baseline(profile['events'])

    def detect_deviation(self, entity_id, new_event):
        """Detect if new event deviates from baseline.

        Returns:
            ``0.0`` when the entity is unknown or has no baseline yet,
            otherwise the score from ``_calculate_deviation``.

        NOTE(review): ``_compute_behavior`` is called here but is not
        defined on this class in the visible source.
        """
        # ``.get`` avoids the defaultdict inserting an empty profile for
        # every entity that is merely looked up.
        profile = self.profiles.get(entity_id)
        if profile is None or not profile['baseline']:
            return 0.0  # No baseline yet

        current_behavior = self._compute_behavior([new_event])
        return self._calculate_deviation(current_behavior, profile['baseline'])

    def _compute_baseline(self, events):
        """Compute baseline behavior from historical events.

        NOTE(review): ``_compute_time_distribution`` is called here but is
        not defined on this class in the visible source.
        """
        # Rate is per hour of the configured window, not a fixed 24 hours.
        window_hours = self.window.total_seconds() / 3600
        requests_per_hour = len(events) / window_hours if window_hours > 0 else 0.0
        return {
            'avg_requests_per_hour': requests_per_hour,
            'avg_bytes_per_request': np.mean([e.bytes for e in events]),
            'unique_destinations': len(set(e.destination for e in events)),
            'common_ports': Counter(e.port for e in events).most_common(5),
            'time_distribution': self._compute_time_distribution(events),
        }

    def _calculate_deviation(self, current, baseline):
        """Calculate deviation score.

        The score is the mean relative difference over the request-rate and
        byte-size components whose baseline value is positive; ``0.0`` when
        neither component is usable.
        """
        deviations = []
        for key in ('avg_requests_per_hour', 'avg_bytes_per_request'):
            reference = baseline[key]
            if reference > 0:
                deviations.append(abs(current[key] - reference) / reference)

        # np.mean([]) would return NaN (with a RuntimeWarning).
        if not deviations:
            return 0.0
        return float(np.mean(deviations))
Real-Time Inference Optimization
// High-performance feature extraction in Rust
use ndarray::{Array1, Array2};
/// Builds a normalized feature matrix from raw events.
pub struct FastFeatureExtractor {
    /// Per-feature mean subtracted during normalization.
    scaler_mean: Array1<f32>,
    /// Per-feature standard deviation used as the divisor.
    scaler_std: Array1<f32>,
}

impl FastFeatureExtractor {
    /// Number of feature columns written per event.
    const NUM_FEATURES: usize = 10;

    /// Returns an `events.len() x NUM_FEATURES` matrix of standardized
    /// features, one row per event.
    ///
    /// # Panics
    /// Panics on out-of-bounds indexing if `events` is non-empty and
    /// `scaler_mean` or `scaler_std` holds fewer than `NUM_FEATURES` values.
    pub fn extract_and_normalize(&self, events: &[Event]) -> Array2<f32> {
        let n = events.len();
        let mut features = Array2::<f32>::zeros((n, Self::NUM_FEATURES));

        for (i, event) in events.iter().enumerate() {
            features[[i, 0]] = event.packet_size as f32;
            features[[i, 1]] = event.packets_per_second;
            features[[i, 2]] = event.bytes_transferred as f32;
            features[[i, 3]] = event.duration_ms;
            features[[i, 4]] = event.unique_ports as f32;
            features[[i, 5]] = event.failed_attempts as f32;
            features[[i, 6]] = event.hour_of_day as f32;
            features[[i, 7]] = event.day_of_week as f32;
            features[[i, 8]] = if event.is_encrypted { 1.0 } else { 0.0 };
            features[[i, 9]] = event.protocol_id as f32;
        }

        // Normalize
        for i in 0..n {
            for j in 0..Self::NUM_FEATURES {
                let std = self.scaler_std[j];
                // A zero deviation (constant feature) would yield inf/NaN;
                // use a unit divisor so the value is only centered.
                let divisor = if std == 0.0 { 1.0 } else { std };
                features[[i, j]] = (features[[i, j]] - self.scaler_mean[j]) / divisor;
            }
        }

        features
    }
}
// Batch inference with ONNX Runtime
use ort::{Session, Value};
/// Thin wrapper around an ONNX Runtime session for batch scoring.
pub struct ONNXInferenceEngine {
    session: Session,
}

impl ONNXInferenceEngine {
    /// Runs the session on a feature batch and returns the first output's
    /// `f32` data as a flat vector.
    ///
    /// # Panics
    /// Panics, with a message naming the failed step, if the input value
    /// cannot be built, the session run fails, or the first output cannot be
    /// extracted as `f32`. Indexing `outputs[0]` also panics if the session
    /// returns no outputs.
    pub fn predict_batch(&self, features: Array2<f32>) -> Vec<f32> {
        let input = Value::from_array(self.session.allocator(), &features)
            .expect("failed to build ONNX input value from feature batch");
        let outputs = self
            .session
            .run(vec![input])
            .expect("ONNX session run failed");
        let output = outputs[0]
            .try_extract::<f32>()
            .expect("failed to extract first ONNX output as f32");
        output.view().to_owned().into_raw_vec()
    }
}
Threat Intelligence Integration
class ThreatIntelligence:
    """Enriches events with threat-intelligence lookups cached in Redis.

    Args:
        redis_client: Client whose ``get`` and ``setex`` methods are
            awaitable.
    """

    def __init__(self, redis_client):
        self.redis = redis_client
        self.update_interval = 3600  # 1 hour

    async def enrich_event(self, event):
        """Enrich event with threat intelligence.

        Returns:
            A new dict combining the event's instance attributes with the
            enrichment fields; enrichment wins on key collisions.

        NOTE(review): ``get_domain_reputation``, ``get_file_reputation`` and
        ``match_attack_patterns`` are awaited here but are not defined on
        this class in the visible source.
        """
        enrichment = {}

        # IP reputation
        enrichment['ip_reputation'] = await self.get_ip_reputation(event.source_ip)

        # Domain reputation. ``getattr`` because not every event type carries
        # a domain (SecurityEvent, for one, has no such attribute).
        domain = getattr(event, 'domain', None)
        if domain:
            enrichment['domain_reputation'] = await self.get_domain_reputation(domain)

        # File hash check (optional attribute, same reasoning as above).
        file_hash = getattr(event, 'file_hash', None)
        if file_hash:
            enrichment['file_reputation'] = await self.get_file_reputation(file_hash)

        # Known attack patterns
        enrichment['known_patterns'] = await self.match_attack_patterns(event)

        return {**vars(event), **enrichment}

    async def get_ip_reputation(self, ip):
        """Get IP reputation score from cache or external sources.

        A cache hit is returned as ``float``; on a miss the fetched score is
        cached for ``self.update_interval`` seconds and returned as-is.
        """
        cache_key = f"ip_rep:{ip}"

        # Check cache
        cached = await self.redis.get(cache_key)
        if cached:
            return float(cached)

        # Fetch from threat feed
        reputation = await self._fetch_ip_reputation(ip)

        # Cache for future lookups, honouring the configured interval
        # instead of a second hard-coded copy of it.
        await self.redis.setex(cache_key, self.update_interval, reputation)

        return reputation

    async def _fetch_ip_reputation(self, ip):
        """Fetch from external threat intelligence sources."""
        # Query AbuseIPDB, VirusTotal, etc.
        # Aggregate scores
        return 0.5  # Placeholder
Alert Correlation and Deduplication
class AlertCorrelator:
    """Groups alerts that arrive close together into incidents.

    Args:
        time_window: How long, in seconds, an alert stays eligible for
            correlation.
    """

    def __init__(self, time_window=300):  # 5 minutes
        self.time_window = time_window
        self.recent_alerts = []

    def process_alert(self, alert):
        """Correlate with recent alerts.

        Args:
            alert: Dict with at least a ``'timestamp'`` key, compared against
                ``time.time()``.

        Returns:
            An incident dict when related recent alerts exist; otherwise the
            alert itself, which is then remembered for later correlation.
        """
        # Clean old alerts
        cutoff = time.time() - self.time_window
        self.recent_alerts = [
            a for a in self.recent_alerts if a['timestamp'] > cutoff
        ]

        # Find related alerts
        related = self._find_related_alerts(alert)
        if related:
            # Correlate into incident
            return self._create_incident(alert, related)

        # New alert
        self.recent_alerts.append(alert)
        return alert

    def _find_related_alerts(self, alert):
        """Find alerts that should be correlated.

        An existing alert is related when it shares the new alert's source
        IP, or shares its attack type and ``_are_targets_related`` accepts
        the two targets. A field missing from the new alert never matches,
        so two alerts that both lack it are not treated as equal.

        NOTE(review): ``_are_targets_related`` is not defined on this class
        in the visible source.
        """
        source_ip = alert.get('source_ip')
        attack_type = alert.get('attack_type')
        target = alert.get('target')

        related = []
        for existing in self.recent_alerts:
            # Same source IP
            if source_ip is not None and existing.get('source_ip') == source_ip:
                related.append(existing)
                continue

            # Same attack pattern
            if (
                attack_type is not None
                and existing.get('attack_type') == attack_type
                and self._are_targets_related(existing.get('target'), target)
            ):
                related.append(existing)

        return related

    @staticmethod
    def _unique_present(values):
        """Return distinct non-None values in first-seen order."""
        return list(dict.fromkeys(v for v in values if v is not None))

    def _create_incident(self, new_alert, related_alerts):
        """Create correlated incident.

        NOTE(review): ``_generate_incident_id`` is not defined on this class
        in the visible source.
        """
        all_alerts = related_alerts + [new_alert]
        timestamps = [a['timestamp'] for a in all_alerts]

        return {
            'incident_id': self._generate_incident_id(),
            'severity': max(a.get('severity', 0) for a in all_alerts),
            'alert_count': len(all_alerts),
            'first_seen': min(timestamps),
            'last_seen': max(timestamps),
            # Ordered de-duplication keeps the output deterministic and
            # leaves out alerts that carry no target / source IP.
            'affected_assets': self._unique_present(
                a.get('target') for a in all_alerts),
            'attack_type': new_alert.get('attack_type'),
            'source_ips': self._unique_present(
                a.get('source_ip') for a in all_alerts),
        }
Adaptive Thresholds
class AdaptiveThreshold:
    """Alert threshold that adapts to analyst feedback.

    Args:
        initial_threshold: Starting risk-score threshold.
        sensitivity: Fraction by which the threshold is scaled down or up on
            each piece of feedback.
    """

    # Bounds the threshold is clamped to after every feedback update.
    MIN_THRESHOLD = 0.5
    MAX_THRESHOLD = 0.95
    # Number of most recent scores retained in ``recent_scores``.
    HISTORY_SIZE = 1000
    # Weight of the newest feedback in the false-positive moving average.
    FEEDBACK_WEIGHT = 0.1

    def __init__(self, initial_threshold=0.8, sensitivity=0.1):
        self.threshold = initial_threshold
        self.sensitivity = sensitivity
        self.recent_scores = []
        self.false_positive_rate = 0.0

    def should_alert(self, risk_score):
        """Determine if score warrants alert.

        Records the score in ``recent_scores`` (bounded to ``HISTORY_SIZE``
        entries) and returns whether it is strictly above the threshold.
        """
        self.recent_scores.append(risk_score)
        if len(self.recent_scores) > self.HISTORY_SIZE:
            self.recent_scores = self.recent_scores[-self.HISTORY_SIZE:]

        return risk_score > self.threshold

    def update_from_feedback(self, was_true_positive):
        """Adjust threshold based on analyst feedback.

        Args:
            was_true_positive: True if the alert was confirmed as a real
                threat, False if it was a false positive.
        """
        if was_true_positive:
            # Lower threshold slightly to catch similar events
            self.threshold *= (1 - self.sensitivity)
            outcome = 0.0
        else:
            # Raise threshold to reduce false positives
            self.threshold *= (1 + self.sensitivity)
            outcome = 1.0

        # Fold *every* outcome into the moving average. Updating only on
        # false positives makes the rate climb towards 1 and never recover.
        self.false_positive_rate = (
            (1 - self.FEEDBACK_WEIGHT) * self.false_positive_rate
            + self.FEEDBACK_WEIGHT * outcome
        )

        # Keep threshold in reasonable range
        self.threshold = max(
            self.MIN_THRESHOLD, min(self.MAX_THRESHOLD, self.threshold))
Conclusion
ML security analytics requires:
- Real-time processing - Stream processing at scale
- Behavioral profiling - Understand normal to detect abnormal
- Efficient inference - Optimize for latency
- Threat intelligence - Enrich with external data
- Alert correlation - Reduce noise
- Adaptive thresholds - Learn from feedback
The key is balancing detection accuracy against operational burden: a system that is too sensitive creates alert fatigue, while one that is too lenient misses threats.