As AI systems handle increasingly critical functions, security moves from nice-to-have to mission-critical. This post presents a comprehensive security framework for production AI systems, covering threat modeling, defense strategies, incident response, and compliance.
AI Security Threat Model
AI systems face unique threats beyond traditional software:
Data Poisoning: Malicious training data corrupts model behavior Model Extraction: Attackers reconstruct proprietary models through queries Prompt Injection: Adversarial inputs bypass safety controls Model Inversion: Extracting training data from model outputs Adversarial Examples: Inputs designed to fool the model Supply Chain Attacks: Compromised dependencies or model weights Resource Exhaustion: DoS through expensive queries
Defense-in-Depth Architecture
βββββββββββββββββββββββββββββββββββββββββββ
β Layer 1: Network & Infrastructure β
β - DDoS protection β
β - Network segmentation β
β - TLS encryption β
βββββββββββββββββββββββββββββββββββββββββββ
β
βββββββββββββββββββββββββββββββββββββββββββ
β Layer 2: Authentication & Authorizationβ
β - OAuth/OIDC β
β - API key management β
β - Role-based access control β
βββββββββββββββββββββββββββββββββββββββββββ
β
βββββββββββββββββββββββββββββββββββββββββββ
β Layer 3: Input Validation β
β - Schema validation β
β - Injection detection β
β - Rate limiting β
βββββββββββββββββββββββββββββββββββββββββββ
β
βββββββββββββββββββββββββββββββββββββββββββ
β Layer 4: Model Security β
β - Adversarial detection β
β - Output filtering β
β - Confidence thresholds β
βββββββββββββββββββββββββββββββββββββββββββ
β
βββββββββββββββββββββββββββββββββββββββββββ
β Layer 5: Monitoring & Response β
β - Anomaly detection β
β - Audit logging β
β - Incident response β
βββββββββββββββββββββββββββββββββββββββββββ
Implementation
from typing import Dict, List, Optional, Set
from dataclasses import dataclass
from enum import Enum
import hashlib
import time
class ThreatLevel(Enum):
NONE = 0
LOW = 1
MEDIUM = 2
HIGH = 3
CRITICAL = 4
@dataclass
class SecurityEvent:
timestamp: float
event_type: str
threat_level: ThreatLevel
source_ip: str
user_id: Optional[str]
details: Dict
mitigations_applied: List[str]
class AISecurityFramework:
"""Comprehensive security framework for AI systems"""
def __init__(self, config: Dict):
self.config = config
self.threat_detector = ThreatDetector()
self.input_validator = InputValidator()
self.output_filter = OutputFilter()
self.audit_log = AuditLog()
self.rate_limiter = RateLimiter()
self.blocked_ips: Set[str] = set()
async def secure_inference(
self,
request: Dict,
context: Dict
) -> Dict:
"""Execute inference with full security stack"""
security_context = {
'request_id': request.get('id', str(uuid.uuid4())),
'user_id': context.get('user_id'),
'source_ip': context.get('source_ip'),
'timestamp': time.time()
}
# Layer 1: IP blocking
if self._is_blocked(security_context['source_ip']):
raise SecurityError("IP address blocked")
# Layer 2: Authentication
await self._authenticate(context)
# Layer 3: Authorization
await self._authorize(context, request)
# Layer 4: Rate limiting
await self.rate_limiter.check_rate_limit(
security_context['user_id'],
security_context['source_ip']
)
# Layer 5: Input validation
validation_result = await self.input_validator.validate(
request['input']
)
if not validation_result['safe']:
await self._handle_threat(
ThreatLevel.HIGH,
"input_validation_failed",
security_context,
validation_result
)
raise SecurityError("Input validation failed")
# Layer 6: Threat detection
threat_assessment = await self.threat_detector.assess(
request['input'],
security_context
)
if threat_assessment['level'].value >= ThreatLevel.HIGH.value:
await self._handle_threat(
threat_assessment['level'],
"threat_detected",
security_context,
threat_assessment
)
raise SecurityError("Potential threat detected")
# Execute inference
try:
output = await self._execute_inference(request['input'])
except Exception as e:
await self.audit_log.log_event({
**security_context,
'event': 'inference_error',
'error': str(e)
})
raise
# Layer 7: Output filtering
filtered_output = await self.output_filter.filter(
output,
security_context
)
# Layer 8: Audit logging
await self.audit_log.log_inference({
**security_context,
'input_hash': self._hash_input(request['input']),
'output_hash': self._hash_input(filtered_output['content']),
'threat_level': threat_assessment['level'].value,
'filters_applied': filtered_output.get('filters_applied', [])
})
return filtered_output
def _is_blocked(self, ip: str) -> bool:
"""Check if IP is blocked"""
return ip in self.blocked_ips
async def _authenticate(self, context: Dict):
"""Verify authentication credentials"""
if 'api_key' not in context and 'oauth_token' not in context:
raise AuthenticationError("Missing credentials")
# Validate API key or OAuth token
# Implementation depends on auth provider
async def _authorize(self, context: Dict, request: Dict):
"""Check authorization for request"""
user_permissions = await self._get_user_permissions(
context.get('user_id')
)
required_permission = self._get_required_permission(request)
if required_permission not in user_permissions:
raise AuthorizationError("Insufficient permissions")
async def _handle_threat(
self,
level: ThreatLevel,
threat_type: str,
context: Dict,
details: Dict
):
"""Handle detected threat"""
event = SecurityEvent(
timestamp=time.time(),
event_type=threat_type,
threat_level=level,
source_ip=context['source_ip'],
user_id=context.get('user_id'),
details=details,
mitigations_applied=[]
)
# Apply mitigations based on threat level
if level == ThreatLevel.CRITICAL:
# Block IP immediately
self.blocked_ips.add(context['source_ip'])
event.mitigations_applied.append('ip_blocked')
# Alert security team
await self._alert_security_team(event)
event.mitigations_applied.append('security_alerted')
elif level == ThreatLevel.HIGH:
# Increase monitoring
await self.rate_limiter.reduce_limit(
context['source_ip'],
factor=0.5
)
event.mitigations_applied.append('rate_limit_reduced')
# Log event
await self.audit_log.log_security_event(event)
def _hash_input(self, data: str) -> str:
"""Create hash of input/output for audit trail"""
return hashlib.sha256(data.encode()).hexdigest()[:16]
async def _execute_inference(self, input_data: str) -> str:
"""Execute actual model inference"""
# Placeholder - implement actual inference
pass
class ThreatDetector:
"""Detect potential security threats in inputs"""
def __init__(self):
self.patterns = self._load_threat_patterns()
self.ml_detector = AdversarialDetector()
async def assess(self, input_text: str, context: Dict) -> Dict:
"""Assess threat level of input"""
threats = []
# Pattern-based detection
for pattern in self.patterns:
if pattern.matches(input_text):
threats.append({
'type': pattern.threat_type,
'level': pattern.level,
'pattern': pattern.name
})
# ML-based adversarial detection
adversarial_score = await self.ml_detector.score(input_text)
if adversarial_score > 0.7:
threats.append({
'type': 'adversarial_input',
'level': ThreatLevel.HIGH,
'score': adversarial_score
})
# Behavioral analysis
behavioral_threats = await self._behavioral_analysis(
input_text,
context
)
threats.extend(behavioral_threats)
# Determine overall threat level
max_level = max(
(t['level'] for t in threats),
default=ThreatLevel.NONE
)
return {
'level': max_level,
'threats': threats,
'score': adversarial_score
}
async def _behavioral_analysis(
self,
input_text: str,
context: Dict
) -> List[Dict]:
"""Analyze behavioral patterns"""
threats = []
# Check for rapid-fire requests (potential model extraction)
user_id = context.get('user_id')
if user_id:
recent_requests = await self._get_recent_requests(user_id)
if len(recent_requests) > 100 and time.time() - recent_requests[0] < 300:
threats.append({
'type': 'potential_model_extraction',
'level': ThreatLevel.MEDIUM,
'request_count': len(recent_requests)
})
# Check for similar inputs (potential probing)
similarity_score = await self._check_input_similarity(
input_text,
context
)
if similarity_score > 0.9:
threats.append({
'type': 'potential_probing',
'level': ThreatLevel.LOW,
'similarity': similarity_score
})
return threats
class RateLimiter:
"""Token bucket rate limiter"""
def __init__(self, default_rate: int = 100, window: int = 60):
self.default_rate = default_rate
self.window = window
self.buckets: Dict[str, Dict] = {}
async def check_rate_limit(self, user_id: str, ip: str):
"""Check if request within rate limit"""
key = f"{user_id}:{ip}"
if key not in self.buckets:
self.buckets[key] = {
'tokens': self.default_rate,
'last_update': time.time(),
'rate': self.default_rate
}
bucket = self.buckets[key]
# Refill tokens
now = time.time()
elapsed = now - bucket['last_update']
refill = (elapsed / self.window) * bucket['rate']
bucket['tokens'] = min(
bucket['rate'],
bucket['tokens'] + refill
)
bucket['last_update'] = now
# Check if tokens available
if bucket['tokens'] < 1:
raise RateLimitError("Rate limit exceeded")
bucket['tokens'] -= 1
async def reduce_limit(self, identifier: str, factor: float):
"""Reduce rate limit for identifier"""
if identifier in self.buckets:
self.buckets[identifier]['rate'] *= factor
class AuditLog:
"""Comprehensive audit logging"""
def __init__(self, storage):
self.storage = storage
async def log_inference(self, data: Dict):
"""Log inference request"""
await self.storage.append('audit_log:inference', {
**data,
'type': 'inference',
'timestamp': time.time()
})
async def log_security_event(self, event: SecurityEvent):
"""Log security event"""
await self.storage.append('audit_log:security', {
'timestamp': event.timestamp,
'event_type': event.event_type,
'threat_level': event.threat_level.name,
'source_ip': event.source_ip,
'user_id': event.user_id,
'details': event.details,
'mitigations': event.mitigations_applied
})
async def query_events(
self,
event_type: Optional[str] = None,
min_threat_level: Optional[ThreatLevel] = None,
time_range: Optional[tuple] = None,
limit: int = 1000
) -> List[Dict]:
"""Query audit logs"""
events = await self.storage.get('audit_log:security', limit=limit)
if event_type:
events = [e for e in events if e['event_type'] == event_type]
if min_threat_level:
events = [
e for e in events
if ThreatLevel[e['threat_level']].value >= min_threat_level.value
]
if time_range:
start, end = time_range
events = [
e for e in events
if start <= e['timestamp'] <= end
]
return events
class ComplianceManager:
"""Manage regulatory compliance requirements"""
def __init__(self, requirements: List[str]):
self.requirements = requirements # e.g., ['GDPR', 'CCPA', 'SOC2']
async def ensure_compliance(self, request: Dict, response: Dict):
"""Verify compliance for request/response"""
violations = []
for requirement in self.requirements:
check_method = getattr(self, f'_check_{requirement.lower()}', None)
if check_method:
result = await check_method(request, response)
if not result['compliant']:
violations.append({
'requirement': requirement,
'issues': result['issues']
})
if violations:
raise ComplianceError(f"Compliance violations: {violations}")
async def _check_gdpr(self, request: Dict, response: Dict) -> Dict:
"""Check GDPR compliance"""
issues = []
# Check for PII in response
if self._contains_pii(response):
issues.append("Response contains PII without consent")
# Check data retention
# Implementation specific to requirements
return {
'compliant': len(issues) == 0,
'issues': issues
}
def _contains_pii(self, data: Dict) -> bool:
"""Check if data contains PII"""
# Implementation using PII detector
return False
class IncidentResponsePlan:
"""Automated incident response"""
def __init__(self):
self.playbooks = self._load_playbooks()
async def execute(self, event: SecurityEvent):
"""Execute incident response for security event"""
playbook = self._select_playbook(event)
if not playbook:
return
# Execute response steps
for step in playbook['steps']:
await self._execute_step(step, event)
def _select_playbook(self, event: SecurityEvent) -> Optional[Dict]:
"""Select appropriate response playbook"""
for playbook in self.playbooks:
if self._matches_criteria(playbook, event):
return playbook
return None
async def _execute_step(self, step: Dict, event: SecurityEvent):
"""Execute response step"""
action = step['action']
if action == 'block_ip':
# Block source IP
pass
elif action == 'alert':
# Send alert
await self._send_alert(step['recipients'], event)
elif action == 'isolate':
# Isolate affected resources
pass
async def _send_alert(self, recipients: List[str], event: SecurityEvent):
"""Send security alert"""
# Implementation for alerting (email, Slack, PagerDuty, etc.)
pass
Conclusion
Securing AI systems requires a comprehensive, multi-layered approach. From network security to model-specific defenses, each layer provides critical protection. As AI systems become more autonomous and handle sensitive data, robust security frameworks are essential for production deployments.
The key is defense in depth: no single layer is perfect, but together they create a resilient security posture that can withstand sophisticated attacks.