Security analytics has evolved from rule-based detection to sophisticated ML systems that can identify novel threats in real time. This post explores how to architect ML-powered security analytics at scale.

Streaming Analytics Architecture

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

class SecurityEvent:
    """Plain record describing a single observed security event.

    The constructor arguments are stored unchanged as the attributes
    ``timestamp``, ``source_ip``, ``event_type`` and ``features``.
    """

    def __init__(self, timestamp, source_ip, event_type, features):
        self.timestamp, self.source_ip, self.event_type, self.features = (
            timestamp,
            source_ip,
            event_type,
            features,
        )

class ThreatDetectionPipeline:
    """Streaming Beam pipeline that scores security events and raises alerts.

    Args:
        model: model handed to ``MLInference``; it must expose ``predict``
            accepting a list of feature rows.
        risk_threshold: results whose ``risk_score`` is strictly greater than
            this value are forwarded to ``AlertSender`` (default 0.8).
    """

    def __init__(self, model, risk_threshold=0.8):
        self.model = model
        self.risk_threshold = risk_threshold

    def build_pipeline(self):
        """Build the pipeline and run it; exiting the ``with`` block runs it."""
        # Pub/Sub is an unbounded source, which the Beam Python SDK only
        # supports when the pipeline runs in streaming mode.
        options = PipelineOptions(streaming=True)
        # Bind to a local so the filter lambda captures a number, not self.
        threshold = self.risk_threshold

        with beam.Pipeline(options=options) as pipeline:
            # NOTE(review): Pub/Sub normally expects a fully-qualified topic
            # ("projects/<project>/topics/<name>") — confirm this short name resolves.
            # NOTE(review): 'ML Inference' unpacks (source_ip, (event, features));
            # that only holds if BehavioralAggregator emits an (event, features) pair — verify.
            (pipeline
             | 'Read Events' >> beam.io.ReadFromPubSub(topic='security-events')
             | 'Parse Events' >> beam.Map(self.parse_event)
             | 'Extract Features' >> beam.ParDo(FeatureExtractor())
             | 'Windowing' >> beam.WindowInto(beam.window.FixedWindows(60))  # 60-second windows
             | 'Aggregate' >> beam.CombinePerKey(BehavioralAggregator())
             | 'ML Inference' >> beam.ParDo(MLInference(self.model))
             # MLInference yields dicts, so the score is read by key, not as an attribute.
             | 'Filter Threats' >> beam.Filter(lambda x: x['risk_score'] > threshold)
             | 'Alert' >> beam.ParDo(AlertSender()))

class FeatureExtractor(beam.DoFn):
    """Key each event by its source IP and attach a dict of derived features.

    Emits ``(source_ip, (event, features))``. ``encode_event_type`` and
    ``get_ip_reputation`` are expected to be provided on this class elsewhere.
    """

    def process(self, event):
        ts = event.timestamp
        feature_map = dict(
            hour_of_day=ts.hour,
            day_of_week=ts.weekday(),
            event_type_encoded=self.encode_event_type(event.event_type),
            source_ip_reputation=self.get_ip_reputation(event.source_ip),
        )
        yield event.source_ip, (event, feature_map)

class MLInference(beam.DoFn):
    """Score one keyed element with the wrapped model.

    Input elements are ``(source_ip, (event, features))`` where ``features``
    is a dict; its values, in insertion order, form the model's input row.
    Emits a dict with ``source_ip``, ``event``, ``risk_score`` and ``features``.
    """

    def __init__(self, model):
        self.model = model

    def process(self, element):
        key, payload = element
        event, features = payload

        row = list(features.values())
        score = self.model.predict([row])[0]

        yield dict(source_ip=key, event=event, risk_score=score, features=features)

Anomaly Detection Models

from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler
import numpy as np

class AnomalyDetector:
    """Isolation-Forest anomaly detector over per-event numeric features.

    Call ``fit`` with known-good traffic, then ``predict`` on new events to get
    the subset flagged as anomalous.
    """

    # Number of columns produced by extract_features(); keeps empty batches 2-D.
    _N_FEATURES = 5

    def __init__(self):
        self.scaler = StandardScaler()
        self.model = IsolationForest(
            contamination=0.01,  # Expect 1% anomalies
            n_estimators=100,
            max_samples=256,
            random_state=42
        )
        self.is_fitted = False

    def fit(self, normal_traffic):
        """Train on known-good traffic"""
        features = self.extract_features(normal_traffic)
        scaled = self.scaler.fit_transform(features)
        self.model.fit(scaled)
        self.is_fitted = True

    def predict(self, events):
        """Detect anomalies in new events.

        Args:
            events: indexable sequence of events (indexed by position below).

        Returns:
            A list of dicts ``{'event', 'anomaly_score', 'features'}`` for each
            event the model labels anomalous; ``[]`` for an empty batch.
        """
        if len(events) == 0:
            # Nothing to score; avoids handing sklearn a zero-sample array.
            return []

        features = self.extract_features(events)
        scaled = self.scaler.transform(features)

        scores = self.model.decision_function(scaled)
        predictions = self.model.predict(scaled)

        # IsolationForest.predict labels outliers -1 and inliers 1.
        return [
            {
                'event': events[i],
                # decision_function is negative for outliers; negate so higher = more anomalous.
                'anomaly_score': -scores[i],
                'features': features[i],
            }
            for i, pred in enumerate(predictions)
            if pred == -1
        ]

    def extract_features(self, events):
        """Extract numerical features from events.

        Returns a 2-D array with one row per event and ``_N_FEATURES`` columns;
        an empty input yields shape ``(0, _N_FEATURES)``.
        """
        rows = [
            [
                event.packet_size,
                event.packets_per_second,
                event.unique_destinations,
                event.failed_attempts,
                self.time_since_last_event(event),
            ]
            for event in events
        ]
        # reshape keeps the result 2-D even when there are no rows.
        return np.array(rows).reshape(-1, self._N_FEATURES)

Behavioral Profiling

from collections import defaultdict
from datetime import datetime, timedelta

class BehavioralProfile:
    """Sliding-window behavior baselines per entity, with deviation scoring.

    Args:
        window_hours: length of the history window kept for each entity.
        min_baseline_events: the baseline is (re)computed once an entity has
            more than this many events inside the window (default 100).
    """

    def __init__(self, window_hours=24, min_baseline_events=100):
        self.window = timedelta(hours=window_hours)
        self.min_baseline_events = min_baseline_events
        self.profiles = defaultdict(lambda: {
            'events': [],
            'baseline': None
        })

    def update(self, entity_id, event):
        """Record ``event`` for ``entity_id``, prune old events, refresh the baseline."""
        profile = self.profiles[entity_id]

        # Add event
        profile['events'].append(event)

        # Remove events outside the window. The cutoff uses naive local
        # datetime.now(), so event timestamps must be comparable naive datetimes.
        cutoff = datetime.now() - self.window
        profile['events'] = [
            e for e in profile['events']
            if e.timestamp > cutoff
        ]

        # Update baseline if enough data
        if len(profile['events']) > self.min_baseline_events:
            profile['baseline'] = self._compute_baseline(profile['events'])

    def detect_deviation(self, entity_id, new_event):
        """Return the deviation of ``new_event`` from the entity's baseline (0.0 if none)."""
        # .get() instead of [] so querying an unseen entity does not create
        # (and permanently store) an empty profile in the defaultdict.
        profile = self.profiles.get(entity_id)

        if profile is None or not profile['baseline']:
            return 0.0  # No baseline yet

        current_behavior = self._compute_behavior([new_event])
        return self._calculate_deviation(current_behavior, profile['baseline'])

    def _compute_baseline(self, events):
        """Compute baseline behavior from the events currently in the window."""
        from collections import Counter  # not imported alongside defaultdict above

        # Events were pruned to the configured window, so the hourly rate must
        # use that window's length rather than a fixed 24 hours.
        window_hours = self.window.total_seconds() / 3600
        return {
            'avg_requests_per_hour': len(events) / window_hours,
            'avg_bytes_per_request': np.mean([e.bytes for e in events]),
            'unique_destinations': len(set(e.destination for e in events)),
            'common_ports': Counter(e.port for e in events).most_common(5),
            'time_distribution': self._compute_time_distribution(events)
        }

    def _calculate_deviation(self, current, baseline):
        """Mean relative deviation of request rate and bytes/request from baseline.

        Metrics whose baseline is not positive are skipped; if none remain the
        deviation is 0.0 rather than NaN.
        """
        deviations = []

        # Request rate deviation
        if baseline['avg_requests_per_hour'] > 0:
            rate_dev = abs(current['avg_requests_per_hour'] - baseline['avg_requests_per_hour']) / baseline['avg_requests_per_hour']
            deviations.append(rate_dev)

        # Byte size deviation
        if baseline['avg_bytes_per_request'] > 0:
            bytes_dev = abs(current['avg_bytes_per_request'] - baseline['avg_bytes_per_request']) / baseline['avg_bytes_per_request']
            deviations.append(bytes_dev)

        if not deviations:
            # np.mean([]) would return NaN (with a RuntimeWarning).
            return 0.0
        return np.mean(deviations)

Real-Time Inference Optimization

// High-performance feature extraction in Rust
use ndarray::{Array1, Array2};

/// Converts `Event`s into a standardized feature matrix with one row per event.
///
/// Each column is standardized as `(x - mean) / std` using the stored
/// per-column statistics.
pub struct FastFeatureExtractor {
    /// Per-column mean; must have at least `NUM_FEATURES` entries.
    scaler_mean: Array1<f32>,
    /// Per-column standard deviation; must have at least `NUM_FEATURES` entries.
    scaler_std: Array1<f32>,
}

impl FastFeatureExtractor {
    /// Number of feature columns produced per event.
    const NUM_FEATURES: usize = 10;

    /// Extracts the fixed feature vector for every event and standardizes each column.
    ///
    /// Returns an array of shape `(events.len(), NUM_FEATURES)`; an empty slice
    /// yields a `(0, NUM_FEATURES)` array.
    ///
    /// # Panics
    /// Panics if `scaler_mean` or `scaler_std` has fewer than `NUM_FEATURES` entries.
    pub fn extract_and_normalize(&self, events: &[Event]) -> Array2<f32> {
        let n = events.len();
        let mut features = Array2::<f32>::zeros((n, Self::NUM_FEATURES));

        for (i, event) in events.iter().enumerate() {
            features[[i, 0]] = event.packet_size as f32;
            features[[i, 1]] = event.packets_per_second;
            features[[i, 2]] = event.bytes_transferred as f32;
            features[[i, 3]] = event.duration_ms;
            features[[i, 4]] = event.unique_ports as f32;
            features[[i, 5]] = event.failed_attempts as f32;
            features[[i, 6]] = event.hour_of_day as f32;
            features[[i, 7]] = event.day_of_week as f32;
            features[[i, 8]] = if event.is_encrypted { 1.0 } else { 0.0 };
            features[[i, 9]] = event.protocol_id as f32;
        }

        // Standardize column by column so each column's statistics are read once.
        for j in 0..Self::NUM_FEATURES {
            let mean = self.scaler_mean[j];
            // A zero std would turn the whole column into inf/NaN. Like
            // scikit-learn's StandardScaler, use a scale of 1.0 for constant features.
            let std = if self.scaler_std[j] == 0.0 { 1.0 } else { self.scaler_std[j] };
            for i in 0..n {
                features[[i, j]] = (features[[i, j]] - mean) / std;
            }
        }

        features
    }
}

// Batch inference with ONNX Runtime
use ort::{Session, Value};

/// Runs batched model inference through an ONNX Runtime session (`ort` crate).
pub struct ONNXInferenceEngine {
    // Loaded ONNX Runtime session; how it is constructed is not shown here.
    session: Session,
}

impl ONNXInferenceEngine {
    /// Scores a batch of feature rows and returns the first model output as a flat `Vec<f32>`.
    ///
    /// `features` is passed as the session's single input; only `outputs[0]` is read.
    ///
    /// # Panics
    /// Panics (via `unwrap`) if building the input value, running the session,
    /// or extracting the first output as `f32` fails.
    // NOTE(review): a panic here takes down the worker for one bad batch; consider
    // returning a `Result` (an interface change for callers). The signatures of
    // `Value::from_array` / `try_extract` differ between `ort` releases — confirm
    // against the pinned version.
    pub fn predict_batch(&self, features: Array2<f32>) -> Vec<f32> {
        let input = Value::from_array(self.session.allocator(), &features).unwrap();

        let outputs = self.session.run(vec![input]).unwrap();
        let output = outputs[0].try_extract::<f32>().unwrap();

        // Copy the extracted tensor into an owned array and take its backing buffer.
        // NOTE(review): into_raw_vec yields elements in memory order; presumably
        // standard (row-major) layout here — confirm.
        output.view().to_owned().into_raw_vec()
    }
}

Threat Intelligence Integration

class ThreatIntelligence:
    """Enrich events with threat-intelligence data, caching IP reputation in Redis.

    Args:
        redis_client: async client exposing awaitable ``get(key)`` and
            ``setex(key, ttl_seconds, value)``.
    """

    def __init__(self, redis_client):
        self.redis = redis_client
        # Lifetime, in seconds, of cached reputation entries.
        self.update_interval = 3600  # 1 hour

    async def enrich_event(self, event):
        """Return ``event``'s attributes merged with threat-intel enrichment fields.

        ``domain`` and ``file_hash`` are optional: events that lack those
        attributes (or have falsy values) simply skip the matching lookup.
        """
        enrichment = {}

        # IP reputation
        enrichment['ip_reputation'] = await self.get_ip_reputation(event.source_ip)

        # Domain reputation (not every event type carries a domain)
        domain = getattr(event, 'domain', None)
        if domain:
            enrichment['domain_reputation'] = await self.get_domain_reputation(domain)

        # File hash check (not every event type carries a file hash)
        file_hash = getattr(event, 'file_hash', None)
        if file_hash:
            enrichment['file_reputation'] = await self.get_file_reputation(file_hash)

        # Known attack patterns
        enrichment['known_patterns'] = await self.match_attack_patterns(event)

        return {**event.__dict__, **enrichment}

    async def get_ip_reputation(self, ip):
        """Get IP reputation score from cache or external sources.

        A cache miss fetches from the threat feed and caches the result for
        ``self.update_interval`` seconds.
        """
        cache_key = f"ip_rep:{ip}"

        # Check cache
        cached = await self.redis.get(cache_key)
        if cached:
            return float(cached)

        # Fetch from threat feed
        reputation = await self._fetch_ip_reputation(ip)

        # Cache for future lookups, honoring the configured refresh interval.
        await self.redis.setex(cache_key, self.update_interval, reputation)

        return reputation

    async def _fetch_ip_reputation(self, ip):
        """Fetch from external threat intelligence sources"""
        # Query AbuseIPDB, VirusTotal, etc.
        # Aggregate scores
        return 0.5  # Placeholder

Alert Correlation and Deduplication

class AlertCorrelator:
    """Correlate alert dicts arriving within a sliding time window into incidents.

    Alerts are dicts. This class reads ``'timestamp'`` (compared against
    ``time.time()``, so presumably epoch seconds) and, when present,
    ``'source_ip'``, ``'attack_type'``, ``'target'`` and ``'severity'``.

    Args:
        time_window: seconds an alert stays eligible for correlation.
    """

    def __init__(self, time_window=300):  # 5 minutes
        self.time_window = time_window
        self.recent_alerts = []

    def process_alert(self, alert):
        """Correlate ``alert`` with recent alerts.

        Returns an incident dict when related alerts exist inside the window,
        otherwise ``alert`` itself. The alert is recorded in either case.
        """
        import time  # ``time`` is not imported elsewhere in this snippet

        # Drop alerts that have aged out of the correlation window.
        cutoff = time.time() - self.time_window
        self.recent_alerts = [a for a in self.recent_alerts if a['timestamp'] > cutoff]

        related = self._find_related_alerts(alert)

        # Record the alert even when it correlates, so a sustained attack keeps
        # correlating for as long as alerts keep arriving inside the window.
        self.recent_alerts.append(alert)

        if related:
            return self._create_incident(alert, related)
        return alert

    def _find_related_alerts(self, alert):
        """Return recent alerts sharing a source IP, or an attack type with related targets.

        Missing (None) fields never count as a match; otherwise every alert
        lacking a source IP would correlate with every other one.
        """
        source_ip = alert.get('source_ip')
        attack_type = alert.get('attack_type')
        related = []

        for existing in self.recent_alerts:
            # Same source IP
            if source_ip is not None and existing.get('source_ip') == source_ip:
                related.append(existing)
                continue

            # Same attack pattern against related targets
            if (attack_type is not None
                    and existing.get('attack_type') == attack_type
                    and self._are_targets_related(existing.get('target'), alert.get('target'))):
                related.append(existing)

        return related

    def _create_incident(self, new_alert, related_alerts):
        """Build an incident summary from ``new_alert`` and its related alerts."""
        all_alerts = related_alerts + [new_alert]

        return {
            'incident_id': self._generate_incident_id(),
            'severity': max(a.get('severity', 0) for a in all_alerts),
            'alert_count': len(all_alerts),
            'first_seen': min(a['timestamp'] for a in all_alerts),
            'last_seen': max(a['timestamp'] for a in all_alerts),
            'affected_assets': self._distinct(a.get('target') for a in all_alerts),
            'attack_type': new_alert.get('attack_type'),
            'source_ips': self._distinct(a.get('source_ip') for a in all_alerts),
        }

    @staticmethod
    def _distinct(values):
        """Unique non-None values in first-seen order (deterministic, unlike set())."""
        return list(dict.fromkeys(v for v in values if v is not None))

Adaptive Thresholds

class AdaptiveThreshold:
    """Alert threshold on risk scores that adapts to analyst feedback.

    Args:
        initial_threshold: starting threshold; scores strictly above it alert.
        sensitivity: fractional step applied per feedback event (0.1 = 10%).
        min_threshold: lower clamp bound for the threshold (default 0.5).
        max_threshold: upper clamp bound for the threshold (default 0.95).
        history_size: number of most recent scores kept in ``recent_scores``.
    """

    def __init__(self, initial_threshold=0.8, sensitivity=0.1,
                 min_threshold=0.5, max_threshold=0.95, history_size=1000):
        self.threshold = initial_threshold
        self.sensitivity = sensitivity
        self.min_threshold = min_threshold
        self.max_threshold = max_threshold
        self.history_size = history_size
        self.recent_scores = []
        # Exponentially weighted (alpha = 0.1) fraction of feedback that was a false positive.
        self.false_positive_rate = 0.0

    def should_alert(self, risk_score):
        """Record ``risk_score`` and return whether it exceeds the current threshold."""
        self.recent_scores.append(risk_score)

        # Keep only the most recent ``history_size`` scores.
        overflow = len(self.recent_scores) - self.history_size
        if overflow > 0:
            del self.recent_scores[:overflow]

        return risk_score > self.threshold

    def update_from_feedback(self, was_true_positive):
        """Adjust threshold and false-positive rate based on analyst feedback."""
        if was_true_positive:
            # Lower threshold slightly to catch similar events
            self.threshold *= (1 - self.sensitivity)
            fp_observation = 0.0
        else:
            # Raise threshold to reduce false positives
            self.threshold *= (1 + self.sensitivity)
            fp_observation = 1.0

        # Update the EWMA on every feedback event so the rate can decay after
        # true positives instead of only ever increasing.
        self.false_positive_rate = 0.9 * self.false_positive_rate + 0.1 * fp_observation

        # Keep threshold in reasonable range
        self.threshold = max(self.min_threshold, min(self.max_threshold, self.threshold))

Conclusion

ML security analytics requires:

  1. Real-time processing - Stream processing at scale
  2. Behavioral profiling - Understand normal to detect abnormal
  3. Efficient inference - Optimize for latency
  4. Threat intelligence - Enrich with external data
  5. Alert correlation - Reduce noise
  6. Adaptive thresholds - Learn from feedback

The key is balancing detection accuracy with operational burden. A system that is too sensitive creates alert fatigue; one that is too lenient misses threats.