Training large AI models requires distributed compute. This post explores practical distributed training patterns, from basic data parallelism to advanced techniques like pipeline and tensor parallelism.

Distributed Training Patterns

Three primary parallelization strategies:

- Data Parallelism: replicate the model across GPUs and split the data.
- Pipeline Parallelism: split the model's layers across GPUs.
- Tensor Parallelism: split individual layers across GPUs.

Data Parallel Training

The most common approach:

import glob
import logging
import os
import re
from typing import Dict, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

class DistributedTrainer:
    """Data-parallel distributed training with DDP and mixed precision.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment
    (as set by a launcher such as ``torchrun``) and initializes an NCCL
    process group.

    Args:
        model: Model to train; moved to this process's GPU and wrapped in DDP.
        train_dataset: Dataset whose items provide ``'input'`` and ``'label'``
            tensors (as read by ``train_epoch``).
        config: Must provide ``'batch_size'`` (per-process batch size) and
            ``'learning_rate'``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_dataset,
        config: Dict
    ):
        # Rank information supplied by the launcher (env:// rendezvous)
        self.rank = int(os.environ['RANK'])
        self.world_size = int(os.environ['WORLD_SIZE'])
        self.local_rank = int(os.environ['LOCAL_RANK'])

        dist.init_process_group(
            backend='nccl',  # NCCL is the GPU collective backend
            init_method='env://'
        )

        # Bind this process to its local GPU
        torch.cuda.set_device(self.local_rank)
        self.device = torch.device(f'cuda:{self.local_rank}')

        # Wrap model in DDP; gradients are all-reduced during backward
        self.model = model.to(self.device)
        self.model = DDP(
            self.model,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            find_unused_parameters=False  # skip the unused-parameter graph walk
        )

        # Each rank sees a disjoint shard of the dataset
        self.train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True
        )

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config['batch_size'],
            sampler=self.train_sampler,
            num_workers=4,
            pin_memory=True
        )

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate']
        )

        # Gradient scaler for mixed precision
        self.scaler = torch.cuda.amp.GradScaler()

    def _compute_loss(
        self,
        outputs: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """Default objective: cross-entropy of logits against class labels.

        Override in a subclass for a different loss.
        """
        return torch.nn.functional.cross_entropy(outputs, labels)

    async def train_epoch(self, epoch: int):
        """Train for one epoch and return the mean per-step loss over all ranks.

        Losses are accumulated on the GPU, so the loop does not force a
        host/device sync every step. A single all-reduce at the end combines
        every rank's loss sum and step count.

        Returns:
            The mean loss over all steps of all ranks, or ``float('nan')`` if
            no rank processed any batch.
        """
        self.model.train()
        self.train_sampler.set_epoch(epoch)  # different shuffle each epoch

        loss_sum = torch.zeros((), device=self.device)
        step_count = 0

        for batch in self.train_loader:
            inputs = batch['input'].to(self.device, non_blocking=True)
            labels = batch['label'].to(self.device, non_blocking=True)

            # Forward pass with automatic mixed precision
            with torch.cuda.amp.autocast():
                outputs = self.model(inputs)
                loss = self._compute_loss(outputs, labels)

            self.optimizer.zero_grad(set_to_none=True)
            self.scaler.scale(loss).backward()

            # Unscale first so clipping sees the true gradient magnitudes
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Local (this-rank) running total; stays on device, no sync here
            loss_sum += loss.detach().float()
            step_count += 1

        # Every rank must join the collective, even one that saw no batches;
        # otherwise the remaining ranks block forever in all_reduce.
        stats = torch.stack([
            loss_sum,
            torch.tensor(float(step_count), device=self.device)
        ])
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        total_loss, total_steps = stats.tolist()

        if total_steps == 0:
            return float('nan')
        return total_loss / total_steps

    def save_checkpoint(self, epoch: int, path: str):
        """Save model/optimizer/scaler state to ``path``.

        Only rank 0 writes; other ranks return immediately without waiting.
        The unwrapped model (``self.model.module``) is saved so the state dict
        has no DDP ``module.`` prefix.
        """
        if self.rank == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.model.module.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scaler_state_dict': self.scaler.state_dict()
            }

            torch.save(checkpoint, path)

    def cleanup(self):
        """Destroy the process group created in ``__init__``."""
        dist.destroy_process_group()

Pipeline Parallelism

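For very large models that don't fit on a single GPU. Note that `Pipe` requires the RPC framework to be initialized first, and that `torch.distributed.pipeline.sync` has been removed from recent PyTorch releases in favor of `torch.distributed.pipelining`: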

from torch.distributed.pipeline.sync import Pipe

class PipelineParallelTrainer:
    """Pipeline-parallel training with ``torch.distributed.pipeline.sync.Pipe``.

    The model's top-level children are split into ``num_gpus`` contiguous
    stages, with stage ``i`` placed on ``cuda:i``. Per the PyTorch ``Pipe``
    documentation, the RPC framework must be initialized before ``Pipe`` is
    constructed.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        num_gpus: int,
        chunks: int = 8
    ):
        # Split model across GPUs
        self.model = self._partition_model(model, num_gpus)

        # Wrap in Pipeline
        self.pipeline = Pipe(
            self.model,
            chunks=chunks,  # number of micro-batches each batch is split into
            checkpoint='except_last'  # activation checkpointing strategy
        )

    @staticmethod
    def _split_sizes(num_layers: int, num_parts: int) -> list:
        """Return per-stage layer counts, spreading any remainder over the first stages."""
        base, extra = divmod(num_layers, num_parts)
        return [base + (1 if i < extra else 0) for i in range(num_parts)]

    def _partition_model(
        self,
        model: torch.nn.Module,
        num_gpus: int
    ) -> torch.nn.Sequential:
        """Partition the model's top-level children across ``cuda:0..num_gpus-1``.

        Raises:
            ValueError: if ``num_gpus`` is < 1 or exceeds the number of
                top-level layers (which would leave a stage empty).
        """
        layers = list(model.children())

        if num_gpus < 1:
            raise ValueError(f"num_gpus must be >= 1, got {num_gpus}")
        if num_gpus > len(layers):
            raise ValueError(
                f"cannot split {len(layers)} top-level layers across {num_gpus} GPUs"
            )

        partitions = []
        start = 0
        for gpu, size in enumerate(self._split_sizes(len(layers), num_gpus)):
            stage = torch.nn.Sequential(*layers[start:start + size])
            partitions.append(stage.to(f'cuda:{gpu}'))
            start += size

        return torch.nn.Sequential(*partitions)

    async def train_step(self, inputs, labels, optimizer=None):
        """Execute one pipeline-parallel forward/backward pass.

        Args:
            inputs: Input batch; moved to ``cuda:0``, the first stage's device.
            labels: Class targets; moved to the device of the pipeline output.
            optimizer: Optional optimizer. When given, gradients are cleared
                before the pass and ``optimizer.step()`` runs afterwards. When
                omitted, gradients accumulate and the caller must step.

        Returns:
            The loss as a Python float.
        """
        if optimizer is not None:
            optimizer.zero_grad(set_to_none=True)

        # Stage 0 lives on cuda:0 (see _partition_model)
        inputs = inputs.to('cuda:0')

        # Pipe.forward returns an RRef; fetch the local output tensor
        outputs = self.pipeline(inputs).local_value()

        # Outputs come off the last stage's GPU; labels must be there too
        labels = labels.to(outputs.device)
        loss = F.cross_entropy(outputs, labels)

        loss.backward()

        if optimizer is not None:
            optimizer.step()

        return loss.item()

Gradient Accumulation

Simulate larger batch sizes:

class GradientAccumulationTrainer:
    """Train with gradient accumulation to emulate a larger effective batch.

    Each micro-batch loss is divided by ``accumulation_steps``, so summing
    gradients over ``accumulation_steps`` backward passes reproduces the
    gradient of their mean loss.

    Raises:
        ValueError: if ``accumulation_steps`` is less than 1.
    """

    def __init__(
        self,
        model,
        optimizer,
        accumulation_steps: int = 4
    ):
        if accumulation_steps < 1:
            raise ValueError(
                f"accumulation_steps must be >= 1, got {accumulation_steps}"
            )
        self.model = model
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps

    def _apply_update(self, scaler):
        """Apply the accumulated gradients through ``scaler`` and clear them."""
        scaler.step(self.optimizer)
        scaler.update()
        self.optimizer.zero_grad(set_to_none=True)

    async def train_step(
        self,
        dataloader,
        scaler,
        device
    ):
        """Run one pass over ``dataloader`` with accumulated optimizer updates.

        Gradients are cleared before the pass. The optimizer steps after every
        ``accumulation_steps`` micro-batches. It steps once more at the end for
        any trailing micro-batches, so their gradients are neither dropped nor
        leaked into the next call. That final partial window is still scaled by
        ``1 / accumulation_steps``.

        Args:
            dataloader: Iterable of dicts with ``'input'`` and ``'label'`` tensors.
            scaler: Mixed-precision gradient scaler exposing ``scale``, ``step``
                and ``update`` (e.g. ``torch.cuda.amp.GradScaler``).
            device: Device the batch tensors are moved to.

        Returns:
            Sum of the scaled (divided by ``accumulation_steps``) micro-batch
            losses, as a float.
        """
        self.model.train()
        # Don't let gradients from before this call leak into the first update
        self.optimizer.zero_grad(set_to_none=True)

        accumulated_loss = 0.0
        pending = 0  # micro-batches whose gradients have not been applied yet

        for batch in dataloader:
            inputs = batch['input'].to(device)
            labels = batch['label'].to(device)

            # Normalize loss to account for accumulation
            with torch.cuda.amp.autocast():
                outputs = self.model(inputs)
                loss = F.cross_entropy(outputs, labels)
                loss = loss / self.accumulation_steps

            scaler.scale(loss).backward()

            # Keep the running total on-device; converted once at the end
            accumulated_loss = accumulated_loss + loss.detach().float()
            pending += 1

            if pending == self.accumulation_steps:
                self._apply_update(scaler)
                pending = 0

        # Flush a trailing partial accumulation window
        if pending:
            self._apply_update(scaler)

        return float(accumulated_loss)
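
ZeRO Optimization

Shard optimizer state, and optionally gradients and parameters, across data-parallel ranks with DeepSpeed:
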
class ZeROOptimizer:
    """ZeRO (Zero Redundancy Optimizer) training through DeepSpeed.

    Stage 1: partition optimizer states
    Stage 2: partition optimizer states + gradients
    Stage 3: partition optimizer states + gradients + parameters
    """

    def __init__(
        self,
        model,
        optimizer_config: Dict,
        stage: int = 2
    ):
        """Build a DeepSpeed engine around ``model``.

        Args:
            model: Model to train.
            optimizer_config: Requires ``'batch_size'`` (used as DeepSpeed's
                ``train_batch_size``) and ``'learning_rate'``. Optionally
                ``'accumulation_steps'`` (default 1) and ``'weight_decay'``
                (default 0.01).
            stage: ZeRO stage, 0-3.

        Raises:
            ValueError: if ``stage`` is not 0, 1, 2 or 3.
        """
        import deepspeed  # optional heavy dependency, imported lazily

        ds_config = self._build_config(optimizer_config, stage)

        # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler)
        self.model_engine, self.optimizer, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config=ds_config
        )

    @staticmethod
    def _build_config(optimizer_config: Dict, stage: int) -> Dict:
        """Translate ``optimizer_config`` into a DeepSpeed config dict."""
        if stage not in (0, 1, 2, 3):
            raise ValueError(f"ZeRO stage must be 0, 1, 2 or 3, got {stage}")

        return {
            "train_batch_size": optimizer_config['batch_size'],
            "gradient_accumulation_steps": optimizer_config.get('accumulation_steps', 1),
            "optimizer": {
                "type": "AdamW",
                "params": {
                    "lr": optimizer_config['learning_rate'],
                    "weight_decay": optimizer_config.get('weight_decay', 0.01)
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 0,  # 0 selects dynamic loss scaling
                "loss_scale_window": 1000,
                "hysteresis": 2,
                "min_loss_scale": 1
            },
            "zero_optimization": {
                "stage": stage,
                "contiguous_gradients": True,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": int(5e8),
                "allgather_bucket_size": int(5e8)
            }
        }

    async def train_step(self, inputs, labels):
        """Run one forward/backward/step through the DeepSpeed engine.

        Returns:
            The loss as a Python float.
        """
        outputs = self.model_engine(inputs)
        loss = F.cross_entropy(outputs, labels)

        # The engine's backward/step replace loss.backward()/optimizer.step()
        self.model_engine.backward(loss)
        self.model_engine.step()

        return loss.item()

Communication Optimization

Efficient collective operations:

class CommunicationOptimizer:
    """Helpers for cheaper / better-overlapped distributed communication."""

    @staticmethod
    async def overlapped_allreduce(
        model: DDP,
        optimizer,
        loss
    ):
        """One optimizer step where gradient all-reduce overlaps backward.

        DDP buckets gradients and launches their all-reduce while backward is
        still running, so no manual communication is needed here.
        ``find_unused_parameters=False`` additionally skips DDP's extra
        traversal for unused parameters.
        """
        optimizer.zero_grad()

        # Backward pass; DDP communicates gradients as buckets become ready
        loss.backward()

        # Gradients are already all-reduced by DDP at this point
        optimizer.step()

    @staticmethod
    async def gradient_compression(
        gradients: torch.Tensor,
        compression_ratio: float = 0.01
    ) -> torch.Tensor:
        """Top-k sparsify ``gradients``, keeping the largest-magnitude entries.

        ``k = numel * compression_ratio``, clamped to ``[1, numel]`` so that
        small tensors are never zeroed out entirely. The result is dense (same
        shape, zeros elsewhere). An actual bandwidth saving requires sending
        only the kept values and their indices.

        Raises:
            ValueError: if ``compression_ratio`` is not in ``(0, 1]``.
        """
        if not 0 < compression_ratio <= 1:
            raise ValueError(
                f"compression_ratio must be in (0, 1], got {compression_ratio}"
            )

        flat = gradients.flatten()
        numel = flat.numel()
        if numel == 0:
            return torch.zeros_like(gradients)

        k = min(numel, max(1, int(numel * compression_ratio)))
        indices = torch.topk(flat.abs(), k).indices

        # Dense tensor holding only the selected entries
        compressed = torch.zeros_like(flat)
        compressed[indices] = flat[indices]

        return compressed.reshape(gradients.shape)

    @staticmethod
    async def hierarchical_allreduce(
        tensor: torch.Tensor,
        local_group,
        global_group
    ):
        """Two-level sum all-reduce: within each node, then across node leaders.

        Assumes ``local_group`` holds this node's ranks and ``global_group``
        holds exactly the local-rank-0 process of every node. Only those
        leaders call the cross-node all-reduce, so any other member of
        ``global_group`` would leave it hanging.
        """
        # Reduce within node
        dist.all_reduce(tensor, group=local_group)

        # Reduce across nodes (only the node leader participates)
        if dist.get_rank(local_group) == 0:
            dist.all_reduce(tensor, group=global_group)

        # broadcast's ``src`` is a *global* rank: translate group rank 0
        leader = dist.get_global_rank(local_group, 0)
        dist.broadcast(tensor, src=leader, group=local_group)

Monitoring and Debugging

Track distributed training health:

class DistributedMonitor:
    """Collect per-rank GPU metrics and compare utilization across ranks.

    Args:
        rank: Global rank of this process (rank 0 collects the balance report).
        world_size: Total number of ranks.
        device: CUDA device index to query. ``None`` (default) uses the
            current CUDA device. The global rank is not a valid device index
            on multi-node jobs.
    """

    def __init__(self, rank: int, world_size: int, device: Optional[int] = None):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.metrics = {
            'gpu_utilization': [],
            'communication_time': [],
            'computation_time': [],
            'memory_usage': []
        }

    def _device_index(self) -> int:
        """Return the CUDA device index this monitor reports on."""
        return torch.cuda.current_device() if self.device is None else self.device

    async def log_step_metrics(self):
        """Append current GPU utilization (%) and allocated memory (GB); no-op without CUDA."""
        if torch.cuda.is_available():
            device_index = self._device_index()

            utilization = torch.cuda.utilization(device_index)
            self.metrics['gpu_utilization'].append(utilization)

            memory_allocated = torch.cuda.memory_allocated(device_index) / 1e9  # GB
            self.metrics['memory_usage'].append(memory_allocated)

    def get_summary(self) -> Dict:
        """Return mean/std/min/max for every metric that has recorded values."""
        import numpy as np

        summary = {}

        for metric_name, values in self.metrics.items():
            if values:
                summary[metric_name] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values)
                }

        return summary

    async def check_gpu_balance(self) -> Dict:
        """Gather GPU utilization to rank 0 and report the spread.

        Every rank must call this (it is a collective). Rank 0 returns the
        report; other ranks return an empty dict.
        """
        device_index = self._device_index()
        device = torch.device('cuda', device_index)

        # float32 on every rank so it matches rank 0's receive buffers
        local_utilization = torch.tensor(
            [float(torch.cuda.utilization(device_index))],
            device=device
        )

        if self.rank == 0:
            all_utilizations = [
                torch.zeros(1, device=device)
                for _ in range(self.world_size)
            ]
            dist.gather(local_utilization, all_utilizations, dst=0)

            utilizations = [u.item() for u in all_utilizations]
            spread = max(utilizations) - min(utilizations)

            return {
                'balanced': spread < 20,  # 20 percentage-point threshold
                'utilizations': utilizations,
                'imbalance_score': spread
            }

        dist.gather(local_utilization, dst=0)
        return {}

Fault Tolerance

Handle failures in distributed training:

class FaultTolerantTrainer:
    """Training loop with periodic checkpoints and automatic resume.

    Subclasses implement ``_train_step(batch)``. Every ``checkpoint_frequency``
    global steps, rank 0 writes ``checkpoint_e{epoch}_s{global_step}.pt`` into
    the ``checkpoint_dir`` given to ``train_with_recovery``. The next call
    loads the newest such file and resumes from the step after it.

    Args:
        model: Model whose ``state_dict`` is checkpointed.
        train_loader: Iterable of batches (iterated once per epoch).
        checkpoint_frequency: Global steps between periodic checkpoints.
        optimizer: Optional optimizer whose state is saved/restored too.
        rank: Global rank. Only rank 0 writes periodic checkpoints, so ranks
            do not race on the same file.
    """

    def __init__(
        self,
        model,
        train_loader,
        checkpoint_frequency: int = 100,
        optimizer=None,
        rank: int = 0
    ):
        self.model = model
        self.train_loader = train_loader
        self.checkpoint_frequency = checkpoint_frequency
        self.optimizer = optimizer
        self.rank = rank
        self.global_step = 0
        # Steps completed in the current epoch, i.e. the in-epoch resume point
        self._epoch_step = 0

    async def train_with_recovery(
        self,
        num_epochs: int,
        checkpoint_dir: str
    ):
        """Train ``num_epochs`` epochs, resuming from the latest checkpoint.

        On failure, logs the error, attempts an emergency checkpoint, and
        re-raises the original exception.
        """
        start_epoch, start_step = self._load_latest_checkpoint(checkpoint_dir)

        for epoch in range(start_epoch, num_epochs):
            try:
                await self._train_epoch(
                    epoch,
                    start_step if epoch == start_epoch else 0,
                    checkpoint_dir
                )

            except Exception as e:
                self._log_error(f"Training failed at epoch {epoch}: {e}")

                # Best effort: a failing save must not mask the original error
                try:
                    self._save_checkpoint(
                        epoch,
                        self._epoch_step,
                        f"{checkpoint_dir}/emergency_{epoch}_{self.global_step}.pt"
                    )
                except Exception as save_error:
                    self._log_error(f"Emergency checkpoint failed: {save_error}")

                raise

    async def _train_epoch(self, epoch: int, start_step: int, checkpoint_dir: str = "."):
        """Train one epoch, skipping the first ``start_step`` batches."""
        self._epoch_step = start_step

        for step, batch in enumerate(self.train_loader):
            if step < start_step:
                continue  # already trained before the checkpoint was taken

            await self._train_step(batch)

            self.global_step += 1
            self._epoch_step = step + 1

            # Periodic checkpointing (rank 0 only, into checkpoint_dir)
            if self.rank == 0 and self.global_step % self.checkpoint_frequency == 0:
                self._save_checkpoint(
                    epoch,
                    self._epoch_step,
                    os.path.join(
                        checkpoint_dir,
                        f"checkpoint_e{epoch}_s{self.global_step}.pt"
                    )
                )

    async def _train_step(self, batch):
        """Run one optimization step on ``batch``; subclasses must override."""
        raise NotImplementedError("subclasses must implement _train_step")

    def _save_checkpoint(self, epoch: int, step: int, path: str):
        """Write model (and optimizer) state plus the resume position to ``path``.

        ``step`` is the number of steps already completed in ``epoch``.
        """
        checkpoint = {
            'epoch': epoch,
            'step': step,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
        }
        if self.optimizer is not None:
            checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()

        torch.save(checkpoint, path)

    def _log_error(self, message: str):
        """Log ``message`` at ERROR level on this module's logger."""
        logging.getLogger(__name__).error(message)

    @staticmethod
    def _checkpoint_sort_key(path: str) -> tuple:
        """Order by modification time; break ties with the ``_s<global_step>`` suffix."""
        match = re.search(r"_s(\d+)\.pt$", path)
        return (os.path.getmtime(path), int(match.group(1)) if match else -1)

    def _load_latest_checkpoint(
        self,
        checkpoint_dir: str
    ) -> tuple:
        """Restore the newest periodic checkpoint in ``checkpoint_dir``, if any.

        Returns:
            ``(epoch, step)`` to resume from, or ``(0, 0)`` when no checkpoint
            exists. Also restores ``self.global_step``.
        """
        pattern = os.path.join(glob.escape(checkpoint_dir), "checkpoint_*.pt")
        checkpoints = glob.glob(pattern)

        if not checkpoints:
            return 0, 0

        latest = max(checkpoints, key=self._checkpoint_sort_key)

        # Load onto CPU; load_state_dict copies tensors onto the model's devices
        checkpoint = torch.load(latest, map_location='cpu')

        self.model.load_state_dict(checkpoint['model_state_dict'])
        if self.optimizer is not None and 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        self.global_step = checkpoint.get('global_step', checkpoint['step'])

        return checkpoint['epoch'], checkpoint['step']

Conclusion

Distributed training is essential for modern AI development. Data parallelism handles most use cases, while pipeline and tensor parallelism enable training of massive models. Communication optimization, gradient accumulation, and fault tolerance ensure efficient, reliable training at scale.

The key is matching the parallelization strategy to your model size, data size, and available hardware. As models continue to grow, mastering distributed training becomes increasingly important.