Training large AI models requires distributed compute. This post explores practical distributed training patterns, from basic data parallelism to advanced techniques like pipeline and tensor parallelism.
Distributed Training Patterns
Three primary parallelization strategies:
Data Parallelism: Replicate the model across GPUs and split the data between them.
Pipeline Parallelism: Split the model's layers across GPUs and stream micro-batches through the stages.
Tensor Parallelism: Split individual layers (e.g., large weight matrices) across GPUs.
Data Parallel Training
The most common approach:
import os
from typing import Dict

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
class DistributedTrainer:
"""Data parallel distributed training"""
def __init__(
self,
model: torch.nn.Module,
train_dataset,
config: Dict
):
# Initialize process group
self.rank = int(os.environ['RANK'])
self.world_size = int(os.environ['WORLD_SIZE'])
self.local_rank = int(os.environ['LOCAL_RANK'])
dist.init_process_group(
backend='nccl', # Use NCCL for GPU
init_method='env://'
)
# Set device
torch.cuda.set_device(self.local_rank)
self.device = torch.device(f'cuda:{self.local_rank}')
# Wrap model in DDP
self.model = model.to(self.device)
self.model = DDP(
self.model,
device_ids=[self.local_rank],
output_device=self.local_rank,
            find_unused_parameters=False  # Skip the unused-parameter scan; faster when every parameter receives a gradient
)
# Distributed sampler
self.train_sampler = DistributedSampler(
train_dataset,
num_replicas=self.world_size,
rank=self.rank,
shuffle=True
)
# DataLoader
self.train_loader = DataLoader(
train_dataset,
batch_size=config['batch_size'],
sampler=self.train_sampler,
num_workers=4,
pin_memory=True
)
# Optimizer
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=config['learning_rate']
)
# Gradient scaler for mixed precision
self.scaler = torch.cuda.amp.GradScaler()
async def train_epoch(self, epoch: int):
"""Train one epoch"""
self.model.train()
self.train_sampler.set_epoch(epoch) # Shuffle differently each epoch
total_loss = 0
step_count = 0
for batch in self.train_loader:
# Move to device
inputs = batch['input'].to(self.device, non_blocking=True)
labels = batch['label'].to(self.device, non_blocking=True)
# Forward pass with automatic mixed precision
with torch.cuda.amp.autocast():
outputs = self.model(inputs)
loss = self._compute_loss(outputs, labels)
# Backward pass
            self.optimizer.zero_grad(set_to_none=True)  # Setting grads to None is cheaper than zeroing them in place
self.scaler.scale(loss).backward()
# Gradient clipping
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# Optimizer step
self.scaler.step(self.optimizer)
self.scaler.update()
# Track loss (average across processes)
total_loss += loss.item()
step_count += 1
# Average loss across all processes
avg_loss = total_loss / step_count
avg_loss_tensor = torch.tensor(avg_loss).to(self.device)
dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.AVG)
return avg_loss_tensor.item()
def save_checkpoint(self, epoch: int, path: str):
"""Save training checkpoint (only rank 0)"""
if self.rank == 0:
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.module.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scaler_state_dict': self.scaler.state_dict()
}
torch.save(checkpoint, path)
def cleanup(self):
"""Cleanup distributed training"""
dist.destroy_process_group()
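The trainer reads RANK, WORLD_SIZE, and LOCAL_RANK from the environment, so it is meant to be launched with torchrun, one process per GPU. Below is a minimal, hypothetical driver: MyDataset stands in for any dataset yielding {'input', 'label'} dicts, and the subclass supplies the _compute_loss hook the trainer leaves undefined.
# train_ddp.py -- launch with: torchrun --nproc_per_node=8 train_ddp.py
import asyncio

import torch.nn as nn
import torch.nn.functional as F

from my_project.data import MyDataset  # hypothetical dataset yielding {'input', 'label'} dicts

class MyTrainer(DistributedTrainer):
    def _compute_loss(self, outputs, labels):
        return F.cross_entropy(outputs, labels)

def main():
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
    trainer = MyTrainer(
        model,
        train_dataset=MyDataset(),
        config={'batch_size': 64, 'learning_rate': 3e-4}
    )
    for epoch in range(10):
        loss = asyncio.run(trainer.train_epoch(epoch))
        if trainer.rank == 0:
            print(f"epoch {epoch}: loss {loss:.4f}")
        trainer.save_checkpoint(epoch, f"checkpoint_{epoch}.pt")
    trainer.cleanup()

if __name__ == "__main__":
    main()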
Pipeline Parallelism
For very large models that don’t fit on a single GPU:
import torch
import torch.nn.functional as F
# Pipe requires the RPC framework to be initialized first, e.g.
# torch.distributed.rpc.init_rpc("worker", rank=0, world_size=1)
# Note: this API is deprecated in newer PyTorch releases in favor of torch.distributed.pipelining
from torch.distributed.pipeline.sync import Pipe
class PipelineParallelTrainer:
"""Pipeline parallel training for large models"""
def __init__(
self,
model: torch.nn.Module,
num_gpus: int,
chunks: int = 8
):
# Split model across GPUs
self.model = self._partition_model(model, num_gpus)
# Wrap in Pipeline
self.pipeline = Pipe(
self.model,
            chunks=chunks,            # Number of micro-batches per mini-batch
            checkpoint='except_last'  # Activation checkpointing strategy
)
def _partition_model(
self,
model: torch.nn.Module,
num_gpus: int
) -> torch.nn.Sequential:
"""Partition model layers across GPUs"""
layers = list(model.children())
layers_per_gpu = len(layers) // num_gpus
partitions = []
for i in range(num_gpus):
start_idx = i * layers_per_gpu
end_idx = start_idx + layers_per_gpu if i < num_gpus - 1 else len(layers)
partition = torch.nn.Sequential(*layers[start_idx:end_idx])
partition = partition.to(f'cuda:{i}')
partitions.append(partition)
return torch.nn.Sequential(*partitions)
async def train_step(self, inputs, labels):
"""Execute pipeline parallel training step"""
        # Forward pass through the pipeline; Pipe returns an RRef to the
        # output, which lives on the last pipeline stage's device
        outputs = self.pipeline(inputs).local_value()
        # Compute loss on the last stage's device
        loss = F.cross_entropy(outputs, labels.to(outputs.device))
# Backward pass
loss.backward()
return loss.item()
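Tensor Parallelism
The third strategy from the overview splits individual layers across GPUs. The sketch below is a simplified, forward-only illustration of the column-/row-parallel linear layers popularized by Megatron-LM, assuming a process group is already initialized; it is not a drop-in implementation (a real one wraps the collectives in autograd functions so gradients flow through them):
import torch
import torch.distributed as dist
import torch.nn as nn

class ColumnParallelLinear(nn.Module):
    """Each rank holds a column shard of the weight; outputs stay sharded."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        world_size = dist.get_world_size()
        assert out_features % world_size == 0
        self.linear = nn.Linear(in_features, out_features // world_size)

    def forward(self, x):
        # Output shape: (..., out_features / world_size)
        return self.linear(x)

class RowParallelLinear(nn.Module):
    """Each rank holds a row shard; partial outputs are summed via all-reduce."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        world_size = dist.get_world_size()
        assert in_features % world_size == 0
        self.linear = nn.Linear(in_features // world_size, out_features, bias=False)

    def forward(self, x_shard):
        partial = self.linear(x_shard)
        # Sum the partial results from all ranks to form the full output
        # (forward-only sketch: plain all_reduce is not autograd-aware)
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)
        return partial
In a transformer MLP block the first projection is column-parallel and the second row-parallel, so each block needs only one all-reduce in the forward pass.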
Gradient Accumulation
Accumulate gradients over several micro-batches to simulate a larger effective batch size:
import torch
import torch.nn.functional as F

class GradientAccumulationTrainer:
"""Training with gradient accumulation"""
def __init__(
self,
model,
optimizer,
accumulation_steps: int = 4
):
self.model = model
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
async def train_step(
self,
dataloader,
scaler,
device
):
"""Training step with gradient accumulation"""
self.model.train()
accumulated_loss = 0
for step, batch in enumerate(dataloader):
inputs = batch['input'].to(device)
labels = batch['label'].to(device)
# Normalize loss to account for accumulation
with torch.cuda.amp.autocast():
outputs = self.model(inputs)
loss = F.cross_entropy(outputs, labels)
loss = loss / self.accumulation_steps
# Backward pass
scaler.scale(loss).backward()
accumulated_loss += loss.item()
# Update weights every N steps
if (step + 1) % self.accumulation_steps == 0:
scaler.step(self.optimizer)
scaler.update()
self.optimizer.zero_grad()
return accumulated_loss
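ZeRO Optimization
Shard optimizer state, gradients, and (at stage 3) parameters across data-parallel ranks to cut per-GPU memory. The class below wraps DeepSpeed's ZeRO implementation:
from typing import Dict

import torch.nn.functional as F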
class ZeROOptimizer:
"""ZeRO (Zero Redundancy Optimizer) for memory efficiency"""
def __init__(
self,
model,
optimizer_config: Dict,
stage: int = 2
):
"""
Stage 1: Partition optimizer states
Stage 2: Partition optimizer states + gradients
Stage 3: Partition optimizer states + gradients + parameters
"""
        import deepspeed
ds_config = {
"train_batch_size": optimizer_config['batch_size'],
"gradient_accumulation_steps": optimizer_config.get('accumulation_steps', 1),
"optimizer": {
"type": "AdamW",
"params": {
"lr": optimizer_config['learning_rate'],
"weight_decay": optimizer_config.get('weight_decay', 0.01)
}
},
"fp16": {
"enabled": True,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": stage,
"contiguous_gradients": True,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8
}
}
        # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler)
        self.model_engine, self.optimizer, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config=ds_config
        )
async def train_step(self, inputs, labels):
"""Training step with ZeRO"""
outputs = self.model_engine(inputs)
loss = F.cross_entropy(outputs, labels)
self.model_engine.backward(loss)
self.model_engine.step()
return loss.item()
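A rough sketch of how this wrapper might be driven. It assumes the script is launched with the deepspeed launcher so distributed state is initialized, and the model and batch below are hypothetical placeholders:
import asyncio

import torch
import torch.nn as nn

# Assumes a single-GPU launch via `deepspeed train.py` for simplicity,
# so train_batch_size equals the per-GPU batch size used below.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))
trainer = ZeROOptimizer(
    model,
    optimizer_config={'batch_size': 32, 'learning_rate': 1e-4},
    stage=2
)

async def run():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
    # Inputs in fp16 because the DeepSpeed config above enables fp16
    inputs = torch.randn(32, 1024, device=device, dtype=torch.half)
    labels = torch.randint(0, 10, (32,), device=device)
    loss = await trainer.train_step(inputs, labels)
    print(f"loss: {loss:.4f}")

asyncio.run(run())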
Communication Optimization
Efficient collective operations:
class CommunicationOptimizer:
"""Optimize distributed communication"""
@staticmethod
async def overlapped_allreduce(
model: DDP,
optimizer,
loss
):
"""Overlap gradient communication with backward pass"""
        # DDP already overlaps gradient all-reduce with the backward pass by
        # bucketing gradients; the bucket size can be tuned via bucket_cap_mb
optimizer.zero_grad()
# Backward pass (gradients communicated automatically during backward)
loss.backward()
# Gradients already all-reduced by DDP
optimizer.step()
@staticmethod
async def gradient_compression(
gradients: torch.Tensor,
compression_ratio: float = 0.01
) -> torch.Tensor:
"""Compress gradients before communication"""
# Top-k sparsification
k = int(gradients.numel() * compression_ratio)
# Get top-k values and indices
values, indices = torch.topk(
gradients.abs().flatten(),
k
)
# Create sparse gradient
compressed = torch.zeros_like(gradients.flatten())
compressed[indices] = gradients.flatten()[indices]
return compressed.reshape(gradients.shape)
@staticmethod
async def hierarchical_allreduce(
tensor: torch.Tensor,
local_group,
global_group
):
"""Two-level hierarchical all-reduce"""
# Reduce within node
dist.all_reduce(tensor, group=local_group)
        # Reduce across nodes (only the first rank of each node participates)
        if dist.get_rank(group=local_group) == 0:
            dist.all_reduce(tensor, group=global_group)
        # Broadcast the result within the node; src must be the global rank
        # of this node's local rank 0, not literal rank 0
        node_leader = dist.get_global_rank(local_group, 0)
        dist.broadcast(tensor, src=node_leader, group=local_group)
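PyTorch's DDP also ships built-in communication hooks, so gradient compression doesn't have to be hand-rolled. A short example registering the fp16 compression hook on an already-DDP-wrapped model (ddp_model is assumed to exist):
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Compress gradients to fp16 before the all-reduce and decompress afterwards.
# state=None uses the default process group.
ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
The same package also provides a PowerSGD hook for more aggressive low-rank compression.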
Monitoring and Debugging
Track distributed training health:
from typing import Dict

import torch
import torch.distributed as dist

class DistributedMonitor:
"""Monitor distributed training metrics"""
def __init__(self, rank: int, world_size: int):
self.rank = rank
self.world_size = world_size
self.metrics = {
'gpu_utilization': [],
'communication_time': [],
'computation_time': [],
'memory_usage': []
}
async def log_step_metrics(self):
"""Log metrics for current step"""
        # GPU utilization (torch.cuda.utilization requires the pynvml package)
if torch.cuda.is_available():
utilization = torch.cuda.utilization(self.rank)
self.metrics['gpu_utilization'].append(utilization)
# Memory usage
memory_allocated = torch.cuda.memory_allocated(self.rank) / 1e9 # GB
self.metrics['memory_usage'].append(memory_allocated)
def get_summary(self) -> Dict:
"""Get training summary"""
import numpy as np
summary = {}
for metric_name, values in self.metrics.items():
if values:
summary[metric_name] = {
'mean': np.mean(values),
'std': np.std(values),
'min': np.min(values),
'max': np.max(values)
}
return summary
async def check_gpu_balance(self) -> Dict:
"""Check if work is balanced across GPUs"""
# Gather utilization from all ranks
local_utilization = torch.tensor([
torch.cuda.utilization(self.rank)
]).cuda()
# Gather to rank 0
if self.rank == 0:
all_utilizations = [
torch.zeros(1).cuda()
for _ in range(self.world_size)
]
dist.gather(local_utilization, all_utilizations, dst=0)
utilizations = [u.item() for u in all_utilizations]
return {
'balanced': max(utilizations) - min(utilizations) < 20, # 20% threshold
'utilizations': utilizations,
'imbalance_score': max(utilizations) - min(utilizations)
}
else:
dist.gather(local_utilization, dst=0)
return {}
Fault Tolerance
Handle failures in distributed training:
import glob
import os

import torch

class FaultTolerantTrainer:
    """Distributed training with fault tolerance"""
    def __init__(
        self,
        model,
        optimizer,
        train_loader,
        checkpoint_frequency: int = 100
    ):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.checkpoint_frequency = checkpoint_frequency
        self.rank = int(os.environ.get('RANK', 0))  # Needed for map_location on recovery
        self.global_step = 0
async def train_with_recovery(
self,
num_epochs: int,
checkpoint_dir: str
):
"""Training with automatic checkpoint and recovery"""
# Try to recover from latest checkpoint
start_epoch, start_step = self._load_latest_checkpoint(checkpoint_dir)
for epoch in range(start_epoch, num_epochs):
try:
await self._train_epoch(epoch, start_step if epoch == start_epoch else 0)
except Exception as e:
self._log_error(f"Training failed at epoch {epoch}: {e}")
# Save emergency checkpoint
self._save_checkpoint(
epoch,
self.global_step,
f"{checkpoint_dir}/emergency_{epoch}_{self.global_step}.pt"
)
raise
async def _train_epoch(self, epoch: int, start_step: int):
"""Train one epoch with checkpointing"""
for step, batch in enumerate(self.train_loader):
if step < start_step:
continue
# Training step
await self._train_step(batch)
self.global_step += 1
# Periodic checkpointing
if self.global_step % self.checkpoint_frequency == 0:
self._save_checkpoint(
epoch,
self.global_step,
f"checkpoint_e{epoch}_s{self.global_step}.pt"
)
def _load_latest_checkpoint(
self,
checkpoint_dir: str
) -> tuple:
"""Load latest checkpoint if exists"""
checkpoints = glob.glob(f"{checkpoint_dir}/checkpoint_*.pt")
if not checkpoints:
return 0, 0
# Find latest checkpoint
latest = max(checkpoints, key=os.path.getctime)
checkpoint = torch.load(latest, map_location=f'cuda:{self.rank}')
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch'], checkpoint['step']
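A sketch of how this might be driven; the training methods are coroutines, so they are run with asyncio, and every name below except FaultTolerantTrainer is a hypothetical placeholder:
import asyncio

trainer = FaultTolerantTrainer(
    model=my_model,            # hypothetical model
    optimizer=my_optimizer,    # hypothetical optimizer
    train_loader=my_loader,    # hypothetical DataLoader
    checkpoint_frequency=100
)

# If a run crashes, relaunching the same script resumes from the latest
# checkpoint found in the checkpoint directory.
asyncio.run(trainer.train_with_recovery(num_epochs=10, checkpoint_dir="checkpoints"))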
Conclusion
Distributed training is essential for modern AI development. Data parallelism handles most use cases, while pipeline and tensor parallelism enable training of massive models. Communication optimization, gradient accumulation, and fault tolerance ensure efficient, reliable training at scale.
The key is matching the parallelization strategy to your model size, data size, and available hardware. As models continue to grow, mastering distributed training becomes increasingly important.