Scaling Deep Learning Training with MPMD Pipeline Parallelism
Anxhelo Xhebraj, Sean Lee, Hanfeng Chen, Vinod Grover
Conference on Machine Learning and Systems 2025 · Day 4 · Session 9: Parallel and Distributed Systems
As deep learning models keep growing, they outstrip the memory of a single device and even a single node, so training them means spreading the work across large clusters of GPUs and other accelerators. **Tensor parallelism** and **fully sharded data parallelism (FSDP)** are comparatively easy to implement. Every device runs the same program on a different shard of the data or weights, so both fit the **Single Program, Multiple Data (SPMD)** model, with collectives such as all-reduce, all-gather, and reduce-scatter doing the communication. General **pipeline parallelism** is much harder. Different devices run different stages of the model (a Multiple Program, Multiple Data, or MPMD, pattern), exchange activations and gradients point to point, and must follow a microbatch schedule that keeps every stage busy. This talk, presented by Sean Lee from NVIDIA at MLSys 2025, introduces **JaxPP**, a system that hides these complexities of pipeline parallelism from users of the JAX ecosystem.
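To make the contrast with SPMD concrete, here is a minimal pure-Python sketch. It is not JaxPP code or its API. Each pipeline stage runs its own instruction list, activations flow forward and gradients flow backward between neighboring stages, and the schedule is plain data handed to a generic simulator. GPipe and 1F1B are standard schedules, but the helpers `gpipe`, `one_f_one_b`, `simulate`, and `peak_in_flight` are hypothetical, and every forward (F) or backward (B) step is assumed to take one time unit.

```python
from typing import Callable, Dict, List, Tuple

Op = Tuple[str, int]  # ("F" or "B", microbatch index)
Schedule = Callable[[int, int, int], List[Op]]  # (stages p, microbatches m, stage s) -> program


def gpipe(p: int, m: int, s: int) -> List[Op]:
    """GPipe: each stage runs all forwards, then all backwards."""
    return [("F", i) for i in range(m)] + [("B", i) for i in range(m)]


def one_f_one_b(p: int, m: int, s: int) -> List[Op]:
    """1F1B: warm up with p - s - 1 forwards, then alternate F and B, then drain."""
    warmup = min(p - s - 1, m)
    ops: List[Op] = [("F", i) for i in range(warmup)]
    f, b = warmup, 0
    while f < m:
        ops += [("F", f), ("B", b)]
        f, b = f + 1, b + 1
    ops += [("B", i) for i in range(b, m)]
    return ops


def simulate(schedule: Schedule, p: int, m: int) -> int:
    """Makespan of a schedule, assuming every F and B takes one time unit."""
    programs = [schedule(p, m, s) for s in range(p)]  # a different program per stage (MPMD)
    done: Dict[Tuple[int, str, int], int] = {}  # (stage, kind, microbatch) -> finish time
    clock = [0] * p
    pc = [0] * p  # index of each stage's next instruction
    while any(pc[s] < len(programs[s]) for s in range(p)):
        progressed = False
        for s in range(p):
            if pc[s] == len(programs[s]):
                continue
            kind, i = programs[s][pc[s]]
            # Activations arrive from stage s - 1; gradients arrive from stage s + 1.
            if kind == "F":
                deps = [(s - 1, "F", i)] if s > 0 else []
            else:
                deps = [(s, "F", i)] + ([(s + 1, "B", i)] if s < p - 1 else [])
            if all(d in done for d in deps):
                start = max([clock[s]] + [done[d] for d in deps])
                clock[s] = done[(s, kind, i)] = start + 1
                pc[s] += 1
                progressed = True
        if not progressed:
            raise RuntimeError("schedule deadlocks")
    return max(clock)


def peak_in_flight(ops: List[Op]) -> int:
    """Most microbatches whose activations a stage must keep alive at once."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


p, m = 4, 8
print(simulate(gpipe, p, m), simulate(one_f_one_b, p, m))  # 22 22
print(peak_in_flight(gpipe(p, m, 0)), peak_in_flight(one_f_one_b(p, m, 0)))  # 8 4
```

In this unit-cost model both schedules finish in 22 steps for p = 4 stages and m = 8 microbatches, matching 2(m + p - 1), so each stage sits idle for a fraction (p - 1)/(m + p - 1) = 3/11 of the time. The benefit of 1F1B here is memory: stage 0 keeps at most 4 microbatches of activations alive instead of 8. Because `simulate` receives the schedule as an ordinary function, trying another schedule needs no change to the execution logic, which is the kind of separation a schedule-agnostic pipeline runtime aims for.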
AI review
JaxPP is a legitimate piece of systems engineering: automatic communication inference, schedule-agnostic pipeline parallelism, and 98% scaling efficiency up to 1024 GPUs are real claims that deserve attention. But this article is a polished summary, not a technical window into the work. The implementation details that would let you reason about tradeoffs (how the jaxpr transformation handles shared weights in practice, what failure modes look like when `pipeline_yields` is misplaced, what the multi-controller runtime actually changes) are gestured at rather than shown. The numbers…
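The idea behind automatic communication inference can be pictured with a small pure-Python sketch, offered under stated assumptions rather than as JaxPP's implementation. The tuple IR only loosely resembles a jaxpr. `BOUNDARY` is a hypothetical stand-in for a stage marker such as the `pipeline_yields` annotation mentioned above, whose real semantics are not reproduced here. `infer_transfers` is a hypothetical helper. Each value belongs to the stage that defines it, and every use of that value from a different stage becomes a point-to-point transfer.

```python
from typing import Dict, List, Optional, Set, Tuple

# Toy straight-line IR loosely modeled on a jaxpr: (output, primitive, inputs).
Eqn = Tuple[str, str, Tuple[str, ...]]
BOUNDARY = None  # hypothetical stage-boundary marker; closes the current stage
Program = List[Optional[Eqn]]


def infer_transfers(program: Program) -> List[Tuple[int, int, str]]:
    """Return sorted (src_stage, dst_stage, var) for every value used outside its stage."""
    stage = 0
    defined_in: Dict[str, int] = {}
    transfers: Set[Tuple[int, int, str]] = set()
    for eqn in program:
        if eqn is BOUNDARY:
            stage += 1
            continue
        out, _prim, ins = eqn
        for var in ins:
            src = defined_in.get(var)  # None: a program input (batch or weight), assumed local
            if src is not None and src != stage:
                transfers.add((src, stage, var))  # a send on src paired with a recv on stage
        defined_in[out] = stage
    return sorted(transfers)


program: Program = [
    ("h0", "matmul", ("x", "w0")),
    ("a0", "relu", ("h0",)),
    BOUNDARY,
    ("h1", "matmul", ("a0", "w1")),
    ("a1", "relu", ("h1",)),
    BOUNDARY,
    ("h2", "matmul", ("a1", "w2")),
    ("out", "add", ("h2", "a0")),  # skip connection from stage 0 straight to stage 2
]

for src, dst, var in infer_transfers(program):
    print(f"stage {src} -> stage {dst}: {var}")
# stage 0 -> stage 1: a0
# stage 0 -> stage 2: a0
# stage 1 -> stage 2: a1
```

The skip connection shows why inferring transfers beats writing them by hand: `a0` must reach stage 2 directly, not just its neighbor. A real system also has to derive the backward-pass transfers, handle weights shared across stages, and decide when each buffer can be freed, none of which this sketch attempts.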