PyTorch Fully Sharded Data Parallel (FSDP)
Copyright Continuum Labs - 2023
This September 2023 paper introduces PyTorch Fully Sharded Data Parallel (FSDP), an industry-grade solution for large model training that enables sharding model parameters across multiple devices.
FSDP is designed to provide a user-friendly experience, handle hardware heterogeneity, optimise resource utilisation, and efficiently manage memory.
FSDP divides a model into smaller units and shards the parameters within each unit.
Sharded parameters are communicated and recovered on-demand before computations and discarded afterwards.
This approach minimizes peak memory consumption by only materializing parameters from one unit at a time.
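The gather, compute, discard cycle can be pictured with a small pure-Python toy. It is only a sketch under stated assumptions: plain lists stand in for tensors, `all_gather` is a local concatenation rather than a real collective, and the names `make_units` and `forward_on_rank` are hypothetical.

```python
from typing import List, Tuple

Unit = List[List[float]]  # unit[r] is the shard of this unit held by rank r


def make_units(sizes: List[int], world_size: int) -> List[Unit]:
    """Build toy units whose parameters are all 1.0, split evenly across ranks."""
    return [[[1.0] * (size // world_size) for _ in range(world_size)] for size in sizes]


def all_gather(unit: Unit) -> List[float]:
    """Stand-in for AllGather: concatenate every rank's shard of one unit."""
    return [value for shard in unit for value in shard]


def forward_on_rank(units: List[Unit], rank: int, x: float) -> Tuple[float, int]:
    """Run a toy forward pass; return the output and the peak element count held."""
    resident = sum(len(unit[rank]) for unit in units)  # local shards stay for the whole pass
    peak = resident
    for unit in units:
        full = all_gather(unit)                 # recover this unit's parameters on demand
        peak = max(peak, resident + len(full))  # local shards plus one unsharded unit
        x = x + sum(full)                       # placeholder for the unit's real computation
        del full                                # discard the unsharded copy afterwards
    return x, peak


# Three units of 4, 8 and 4 parameters (16 in total) sharded over 4 ranks.
output, peak = forward_on_rank(make_units([4, 8, 4], world_size=4), rank=0, x=0.0)
```

In this toy, rank 0 never holds more than 12 elements at once (4 resident shard elements plus the largest 8-element unit), even though the model has 16 parameters.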
FSDP introduces deferred initialization, allowing users to create a model instance on a dummy device and record initialization operations.
The model is then initialized and sharded unit by unit by replaying the recorded operations on a real GPU device.
This technique provides a user experience similar to local training while effectively scaling large models.
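The record-then-replay idea can be sketched without any framework. The `InitRecorder` class and the helper functions below are hypothetical illustrations, not PyTorch APIs; they assume each initialization operation can be captured as a function of the parameter size.

```python
from typing import Callable, Dict, List, Tuple

InitFn = Callable[[int], List[float]]


class InitRecorder:
    """Hypothetical stand-in for a dummy device: records init ops, allocates nothing."""

    def __init__(self) -> None:
        self.ops: Dict[str, List[Tuple[str, int, InitFn]]] = {}

    def record(self, unit: str, name: str, size: int, init_fn: InitFn) -> None:
        self.ops.setdefault(unit, []).append((name, size, init_fn))

    def replay(self, unit: str) -> Dict[str, List[float]]:
        """Replay one unit's recorded ops, as if on the real device."""
        return {name: init_fn(size) for name, size, init_fn in self.ops[unit]}


def take_shard(values: List[float], rank: int, world_size: int) -> List[float]:
    """Pad to a multiple of world_size, then keep this rank's slice."""
    chunk = -(-len(values) // world_size)  # ceiling division
    padded = values + [0.0] * (chunk * world_size - len(values))
    return padded[rank * chunk:(rank + 1) * chunk]


def init_and_shard(recorder: InitRecorder, rank: int, world_size: int) -> Dict[str, Dict[str, List[float]]]:
    local: Dict[str, Dict[str, List[float]]] = {}
    for unit in recorder.ops:         # one unit at a time
        full = recorder.replay(unit)  # only this unit is fully materialised
        local[unit] = {n: take_shard(v, rank, world_size) for n, v in full.items()}
    return local


recorder = InitRecorder()
recorder.record("block0", "weight", 6, lambda n: [0.5] * n)
recorder.record("block1", "weight", 5, lambda n: [1.0] * n)
local_state = init_and_shard(recorder, rank=1, world_size=2)
```

Because each unit is replayed and sharded before the next one is touched, the full model never needs to exist on one device in this sketch.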
FSDP offers configurable sharding strategies that can be customized to match the physical interconnect topology of the cluster.
This allows FSDP to handle hardware heterogeneity, such as high-bandwidth islands within machines and low-bandwidth mesh across machines.
Resource Utilisation
FSDP minimizes downtime caused by non-computational operations to ensure GPU devices remain fully utilised during distributed training.
It employs techniques like operation reordering and parameter prefetching to overlap communication with computation and eliminate execution bubbles.
FSDP optimizes memory usage by prudently restricting the number of blocks allocated for in-flight unsharded parameters.
It can suspend CPU execution if necessary to avoid triggering frequent memory defragmentation, which can slow down training.
Evaluation and Results
FSDP was evaluated on various models, including language models and recommendation system models, using up to 512 80GB A100 GPUs.
Experiments showed that FSDP achieves similar performance to Distributed Data Parallel (DDP) on small models.
FSDP enables training significantly larger models with near-linear scalability in terms of TFLOPS.
FSDP is a beta feature as of the PyTorch 2.0 release and has been battle-tested in both industrial and research applications.
It tackles challenges in large model training by providing a user-friendly experience, handling hardware heterogeneity, optimizing resource utilization, and efficiently managing memory.
DistributedDataParallel (DDP) is PyTorch's first end-to-end distributed training feature that follows the model replication approach.
DDP maintains a model replica on each device and synchronises gradients through AllReduce operations during the backward pass.
It overlaps gradient communication with backward computation to accelerate training.
However, DDP requires all model parameters, gradients, and optimiser states to fit in the memory of a single GPU device, limiting its ability to support large models.
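A rough back-of-the-envelope calculation shows how quickly this limit is reached. The sketch below assumes 32-bit values and an Adam-style optimiser that keeps two extra states per parameter; it ignores activations, and the 10-billion-parameter model is a hypothetical example.

```python
def replicated_state_bytes(
    num_params: int,
    bytes_per_value: int = 4,             # assumption: fp32 everywhere
    optimizer_states_per_param: int = 2,  # assumption: Adam-style optimiser
) -> int:
    """Bytes each GPU must hold when parameters, gradients and optimiser states are replicated."""
    params = num_params * bytes_per_value
    grads = num_params * bytes_per_value
    optimizer = num_params * bytes_per_value * optimizer_states_per_param
    return params + grads + optimizer


GPU_MEMORY_BYTES = 80 * 10**9                # one 80GB device
needed = replicated_state_bytes(10 * 10**9)  # hypothetical 10-billion-parameter model
fits_on_one_gpu = needed <= GPU_MEMORY_BYTES
```

Under these assumptions the replicated state alone needs 160 GB per device, twice the capacity of an 80GB GPU, before any activations are counted.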
As model sizes grow, partitioning the model into smaller components and distributing them across multiple devices becomes a viable solution.
Pipeline parallelism breaks a sequence of layers into stages and feeds inputs to different stages in a pipelined manner to optimize resource utilisation.
Tensor RPC provides a lower-level toolkit for executing arbitrary computations on remote devices.
While these techniques can scale large models across multiple devices, they either limit the model to a sequence of stages or require modifications to the model authoring code, which can be a significant obstacle to adoption.
Sharding model parameters can help reduce memory footprint and support models with sizes beyond the memory capacity of a single GPU device.
After sharding, each rank holds only a shard of the model parameters, preventing it from performing the same computations as local training.
To ensure correctness, the training process can employ one or both of the following techniques:
a. Perform computations with parameter shards and communicate activations accordingly, avoiding the need to fully materialize any parameter. However, communication appears in the critical path and cannot easily overlap with computations.
b. Perform the same computation as local training by communicating parameters on-demand before computations. This allows parameter communications to overlap with preceding computations, but requires the on-demand communicated parameters to be fully materialised and fit in the memory of a single GPU device.
FSDP falls into the second category of communicating parameters on-demand. Based on observations and experiments, this approach is sufficient to support the majority of large model applications today and in the near future. If the requirement of fully materialising each parameter unit on GPU becomes a blocker, the two techniques can be combined to support such use cases.
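The difference between the two techniques can be seen on a toy matrix-vector product. This is a minimal sketch, assuming the weight matrix is sharded by rows and that plain list operations stand in for collectives; `technique_a` and `technique_b` are illustrative names only.

```python
from typing import List

Matrix = List[List[int]]


def matvec(w: Matrix, x: List[int]) -> List[int]:
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def shard_rows(w: Matrix, world_size: int) -> List[Matrix]:
    """Give each rank an equal block of rows (assumes rows divide evenly)."""
    per_rank = len(w) // world_size
    return [w[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def technique_a(shards: List[Matrix], x: List[int]) -> List[int]:
    """Compute with parameter shards, then communicate the partial activations."""
    partial_outputs = [matvec(shard, x) for shard in shards]   # one per rank
    return [v for part in partial_outputs for v in part]       # gather activations


def technique_b(shards: List[Matrix], x: List[int]) -> List[int]:
    """Communicate parameters on demand, then compute exactly as in local training."""
    full_weight = [row for shard in shards for row in shard]   # gather parameters
    return matvec(full_weight, x)


weight = [[1, 2], [3, 4], [5, 6], [7, 8]]
x = [1, 1]
shards = shard_rows(weight, world_size=2)
result_a = technique_a(shards, x)
result_b = technique_b(shards, x)
```

Both paths produce the same output; they differ in what travels between ranks (activations in the first, parameters in the second) and in whether the full weight must fit on one device.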
Imagine FSDP as a master chef tasked with preparing a grand feast for a large gathering.
The feast represents the training of a large model, and the dishes symbolise the model's parameters. The chef's kitchen is a single GPU device, which has limited space to work with.
To tackle this challenge, the master chef (FSDP) employs a clever strategy. Instead of trying to prepare all the dishes (parameters) at once, which would overwhelm the kitchen (GPU), the chef divides the menu (model) into smaller, manageable units. Each unit consists of a subset of dishes (layers) that can be prepared independently.
During the preparation process (forward and backward computation), the chef focuses on one unit at a time.
The chef gathers all the necessary ingredients (unsharded parameters) for the dishes in the current unit from the kitchen's storage and other sous-chefs (peer ranks). With everything in place, the chef prepares the dishes (performs local computations) for that unit. Once the unit is complete, the chef cleans up the kitchen by putting away the extra ingredients (freeing peer shards) to make space for the next unit.
Throughout the feast preparation, the chef keeps the pantry (optimizer states) organized by storing ingredients (parameters) in smaller, individual portions (shards). This allows the chef to efficiently manage the limited kitchen space.
The kitchen's capacity (memory requirements) depends on the size of the portioned ingredients (sharded model) and the space needed for the largest unit's preparation (fully-materialized FSDP unit).
In this metaphorical representation, FSDP is portrayed as a master chef who efficiently manages the preparation of a grand feast (training a large model) by dividing the menu (model) into smaller units, focusing on one unit at a time, and optimizing kitchen space usage (memory) and cooking processes (computations and communications) to create a successful event.
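Outside the metaphor, the memory relationship can be written as simple arithmetic. The helper below is a sketch that counts parameter elements only (no gradients, optimiser states or activations), and the unit sizes are hypothetical.

```python
from typing import List


def peak_parameter_elements(unit_sizes: List[int], sharding_factor: int) -> int:
    """Sharded model size plus the largest fully materialised unit."""
    sharded_model = -(-sum(unit_sizes) // sharding_factor)  # ceiling division
    largest_unit = max(unit_sizes)
    return sharded_model + largest_unit


# The same 12,000-element model, sharded 8 ways, wrapped two different ways.
one_big_unit = peak_parameter_elements([12_000], sharding_factor=8)
twelve_units = peak_parameter_elements([1_000] * 12, sharding_factor=8)
```

Wrapping the model as one unit gives a peak of 13,500 elements in this estimate, while twelve equal units bring it down to 2,500, which is why the choice of unit boundaries matters.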
The sharding strategy in FSDP is a crucial element that determines memory footprint and communication overhead.
FSDP supports various sharding strategies, ranging from fully replicated to fully sharded, which are generalised using the sharding factor (𝐹).
The sharding factor represents the number of ranks over which parameters are sharded.
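As a small illustration, the sharding factor can be mapped to a strategy name and a per-rank parameter count. The function names are hypothetical, and the requirement that 𝐹 divides the world size evenly is an assumption made to keep the sketch simple.

```python
def strategy_for(sharding_factor: int, world_size: int) -> str:
    """Name the strategy implied by a sharding factor F on W ranks."""
    if not 1 <= sharding_factor <= world_size:
        raise ValueError("sharding factor must be between 1 and the world size")
    if world_size % sharding_factor != 0:
        raise ValueError("this sketch assumes F divides the world size")
    if sharding_factor == 1:
        return "fully replicated"
    if sharding_factor == world_size:
        return "fully sharded"
    return "hybrid sharded"


def parameters_per_rank(total_params: int, sharding_factor: int) -> int:
    """Each rank keeps roughly 1/F of the parameters (rounded up for padding)."""
    return -(-total_params // sharding_factor)


names = [strategy_for(f, world_size=16) for f in (1, 8, 16)]
per_rank = [parameters_per_rank(1_000_000, f) for f in (1, 8, 16)]
```

A larger 𝐹 means fewer parameters per rank but more ranks taking part in every collective.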
Full sharding, where 𝐹 equals the total number of ranks, has the lowest memory footprint but the highest communication overhead.
FSDP organizes parameters into a FlatParameter, which coalesces communications and evenly shards them across ranks.
The flatten-concat-chunk algorithm minimizes padding and enables efficient AllGather and ReduceScatter operations.
Fine-grained FlatParameter construction reduces peak memory but may decrease throughput due to more collectives.
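A flatten-concat-chunk step might look like the following sketch. It assumes nested Python lists in place of tensors and right-side zero padding so the flat buffer divides evenly by the sharding factor; the function names and the example layer shape are illustrative.

```python
from typing import List, Tuple


def flatten(tensor) -> List[float]:
    """Recursively flatten a nested list (a toy tensor) into one dimension."""
    if isinstance(tensor, list):
        return [value for item in tensor for value in flatten(item)]
    return [tensor]


def flatten_concat_chunk(params: list, sharding_factor: int) -> Tuple[List[List[float]], int]:
    """Return one equal-sized shard per rank and the number of padding elements."""
    flat = [value for p in params for value in flatten(p)]  # flatten, then concat
    chunk = -(-len(flat) // sharding_factor)                # ceiling division
    padding = chunk * sharding_factor - len(flat)           # always fewer than F elements
    flat = flat + [0.0] * padding                           # pad on the right
    shards = [flat[r * chunk:(r + 1) * chunk] for r in range(sharding_factor)]
    return shards, padding


# Toy linear layer: a 4x3 weight and a 3-element bias, 15 values in total.
weight = [[1.0] * 3 for _ in range(4)]
bias = [2.0] * 3
shards, padding = flatten_concat_chunk([weight, bias], sharding_factor=16)
```

With 15 values and 16 ranks, a single padding element is enough for every rank to hold exactly one element, so AllGather and ReduceScatter can work on equal-sized inputs.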
Hybrid sharding combines sharding and replication, providing a trade-off between memory saving and throughput degradation.
It exploits datacentre locality by mapping the device mesh onto the datacentre layout.
This reduces cross-host traffic and improves the performance of AllReduce collectives.
It suits medium-sized models that are too large for full replication but not large enough to fully utilise memory with full sharding.
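One way to picture the combination of sharding and replication is as two sets of rank groups. The sketch below assumes consecutive ranks sit on the same host and that the sharding factor divides the world size; `hybrid_groups` is a hypothetical helper, not a PyTorch function.

```python
from typing import List, Tuple


def hybrid_groups(world_size: int, sharding_factor: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Split ranks into shard groups (e.g. within a host) and replica groups (across hosts)."""
    if world_size % sharding_factor != 0:
        raise ValueError("this sketch assumes F divides the world size")
    shard_groups = [
        list(range(start, start + sharding_factor))
        for start in range(0, world_size, sharding_factor)
    ]
    # Ranks holding the same shard index in different shard groups replicate each other.
    replica_groups = [
        list(range(offset, world_size, sharding_factor))
        for offset in range(sharding_factor)
    ]
    return shard_groups, replica_groups


shard_groups, replica_groups = hybrid_groups(world_size=8, sharding_factor=4)
```

With 8 ranks and a sharding factor of 4, parameters are sharded inside `[0, 1, 2, 3]` and `[4, 5, 6, 7]`, while pairs such as `[0, 4]` hold identical shards and only need to exchange gradients with each other across hosts.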
FlatParameter interoperates with PyTorch's autograd engine for correct gradient propagation and timely gradient reduction.
Original parameters are set as views into the unsharded FlatParameter before forward computation.
A gradient hook is registered to perform post-backward logic, including gradient reduction.
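The notion of original parameters being views into one flat buffer can be imitated with the Python standard library. This is only an analogy: `array` and `memoryview` stand in for tensors and tensor views, the 12 + 3 layout is a hypothetical example, and no autograd or gradient hook is modelled.

```python
from array import array

flat_parameter = array("f", [0.0] * 15)  # toy unsharded FlatParameter
storage = memoryview(flat_parameter)

# "Original parameters" are slices that share memory with the flat buffer.
weight = storage[0:12]
bias = storage[12:15]

flat_parameter[12] = 1.5  # a write to the flat buffer is visible through the view
weight[0] = 2.0           # a write through a view lands in the flat buffer

bias_first = bias[0]
flat_first = flat_parameter[0]
```

Because the slices share storage with the flat buffer, no copy is needed when the unsharded buffer is rebuilt before computation.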
FSDP overlaps communication and computation by using separate CUDA streams.
Backward prefetching issues the next AllGather before the current ReduceScatter.
Forward prefetching issues the next AllGather before the forward computation of the current FSDP unit.
Gradient accumulation can run with or without communication, trading off memory usage against communication.
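The effect of backward prefetching on the order of collectives can be sketched as a simple schedule generator. This toy assumes all collectives are issued in order on a single communication stream and that units are visited in reverse during the backward pass; it models ordering only, not timing.

```python
from typing import List


def backward_comm_order(num_units: int, prefetch: bool) -> List[str]:
    """Order in which collectives are issued during a toy backward pass."""
    order: List[str] = []
    issued = set()
    units = list(range(num_units - 1, -1, -1))  # backward runs last unit first
    for position, unit in enumerate(units):
        if unit not in issued:                  # parameters needed for this unit's backward
            order.append(f"AllGather[{unit}]")
            issued.add(unit)
        # (backward computation of `unit` runs here on the compute stream)
        if prefetch and position + 1 < len(units):
            upcoming = units[position + 1]
            order.append(f"AllGather[{upcoming}]")  # issued before the current ReduceScatter
            issued.add(upcoming)
        order.append(f"ReduceScatter[{unit}]")
    return order


without_prefetch = backward_comm_order(3, prefetch=False)
with_prefetch = backward_comm_order(3, prefetch=True)
```

Without prefetching, each AllGather queues behind the previous unit's ReduceScatter; with prefetching, the next unit's parameters are already requested by the time gradients are reduced, so the following backward computation does not have to wait for them.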
FSDP also addresses memory management challenges related to PyTorch's CUDA caching allocator.
The caching allocator's behavior can lead to over-allocation, and to blocking cudaFree and cudaMalloc calls, when the CPU thread runs ahead of GPU execution.
FSDP introduces a rate limiter that intentionally blocks the CPU thread to ensure proper caching allocator block reuse, allowing at most two in-flight AllGathers to achieve communication and computation overlap while minimising memory fragmentation.
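The rate limiter's behaviour can be approximated with a bounded queue. The sketch below is a single-threaded simulation under stated assumptions: "waiting" is modelled by retiring the oldest in-flight AllGather, and the function name and log format are hypothetical.

```python
from collections import deque
from typing import List, Tuple


def issue_all_gathers(num_units: int, limit: int = 2) -> Tuple[List[str], int]:
    """Issue AllGathers in order, never keeping more than `limit` in flight."""
    in_flight = deque()
    log: List[str] = []
    max_in_flight = 0
    for unit in range(num_units):
        if len(in_flight) == limit:
            oldest = in_flight.popleft()  # CPU thread blocks until this buffer is freed
            log.append(f"wait_free[{oldest}]")
        in_flight.append(unit)
        log.append(f"all_gather[{unit}]")
        max_in_flight = max(max_in_flight, len(in_flight))
    return log, max_in_flight


log, max_in_flight = issue_all_gathers(num_units=4, limit=2)
```

Two in-flight AllGathers are enough for one to overlap with the computation that consumes the other, while the cap keeps the CPU from requesting new blocks before earlier ones can be reused.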