Training large language models (LLMs) like GPT-4 requires distributed computing patterns: the training data is vast and the models have billions of parameters, while the memory of a single GPU is limited (for example, 80 GB on an NVIDIA A100). In this blog, we will delve into two of the most important distributed LLM training patterns: distributed data parallel (DDP) and fully sharded data parallel (FSDP). The primary difference between these patterns is how the model is split, or sharded, across the GPUs in the system. For greater detail, check out this book: Generative AI on AWS.
Distributed Data Parallel (DDP) Pattern
In the DDP computing pattern, a complete copy of the model is loaded onto each GPU in the distributed computing environment. Often, a technique called quantization is used so that a single copy of the model fits on each GPU. Quantization reduces the precision of model parameters (like weights and activations) to lower-bit representations, aiming to decrease memory usage and computational requirements while maintaining model performance. For example, the precision of model parameters can be reduced from 32-bit to 16-bit or 8-bit. Going from 32-bit to 16-bit cuts the GPU memory needed for the parameters in half, and going to 8-bit cuts it to a quarter. Read this post for the details: LLM GPU memory requirements.
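To make the quantization arithmetic concrete, here is a minimal sketch that estimates the memory needed for the model weights alone at different precisions. The function name `weight_memory_gb` and the 1-billion-parameter model are assumptions chosen for illustration, and 1 GB is taken as 10^9 bytes.

```python
def weight_memory_gb(num_params: int, bits_per_param: int) -> float:
    """Memory needed to hold only the model weights, in GB (1 GB = 1e9 bytes)."""
    bytes_per_param = bits_per_param / 8
    return num_params * bytes_per_param / 1e9


num_params = 1_000_000_000  # hypothetical 1-billion-parameter model

for bits in (32, 16, 8):
    print(f"{bits}-bit: {weight_memory_gb(num_params, bits):.1f} GB")
# 32-bit: 4.0 GB
# 16-bit: 2.0 GB
# 8-bit: 1.0 GB
```

Note that this covers the weights only. Training needs additional memory for gradients, optimizer states, and activations.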
The following represents how LLM training happens based on DDP:
- Model replication: Initially, you have one instance (or copy) of your model. In DDP, this model is replicated across all available GPUs. This means that if you have, say, 4 GPUs, each of these GPUs will have its own complete copy of the model.
- Parameter consistency: Although the model is replicated, each copy starts with the same parameters. This consistency is crucial for starting the training process uniformly across all GPUs.
- Parallel data processing: The data is split into batches, and a different batch is sent to each GPU. Each GPU then runs the forward and backward passes on its own batch, in parallel with the other GPUs.
- Gradient synchronization: After the backward pass, the results from each GPU (i.e., the gradients) are combined (e.g., averaged) across all GPUs.
- Model update: Each model copy (one per GPU) in the distributed computing environment is updated with the combined gradients, so all copies stay identical, and the process continues with the next batches.
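The steps above can be sketched as a small single-process simulation. This is not PyTorch's DDP implementation, only an illustration under simplifying assumptions: the "model" is a single weight `w` in `y_hat = w * x`, the "GPUs" are entries in a Python list, and `all_reduce_mean` is a hypothetical stand-in for the collective operation that averages gradients across workers.

```python
def local_gradient(w, batch):
    """Gradient of the mean squared error of y_hat = w * x on one worker's batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Stand-in for an all-reduce: every worker receives the same average."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def ddp_step(replicas, shards, lr):
    """One DDP-style step: local gradients, synchronization, identical updates."""
    grads = [local_gradient(w, shard) for w, shard in zip(replicas, shards)]
    synced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(replicas, synced)]


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # samples of y = 2x
world_size = 2                                            # two simulated GPUs

# Each worker gets a different slice of the data.
shards = [data[rank::world_size] for rank in range(world_size)]

# Model replication: every replica starts from the same parameter value.
replicas = [0.0] * world_size

for _ in range(200):
    replicas = ddp_step(replicas, shards, lr=0.01)

print(replicas)  # both replicas hold the same value, close to 2.0
```

Because every replica starts from the same weight and applies the same averaged gradient, the copies never drift apart, which is the property the parameter consistency and gradient synchronization steps are meant to guarantee.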
When implementing DDP, each GPU must fit not only the LLM parameters and data batches but also the additional data needed for the training loop, including optimizer states, activations, and temporary function variables. In DDP, each GPU holds a full copy of everything needed to perform the forward and backward passes. If this data cannot be stored on a single GPU, another distributed pattern comes into the picture: Fully Sharded Data Parallel (FSDP).
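As a rough illustration of this fit check, the sketch below estimates the per-GPU training state and picks a pattern. It assumes mixed-precision training with the Adam optimizer, using the accounting from the ZeRO paper (2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for fp32 optimizer states). The function names and the `reserved_fraction` set aside for activations and temporary buffers are hypothetical choices, not measured values.

```python
# Bytes per parameter, assuming mixed-precision training with Adam.
BYTES_PER_PARAM = {
    "fp16_weights": 2,
    "fp16_gradients": 2,
    "fp32_optimizer_states": 12,  # fp32 weight copy + Adam momentum + Adam variance
}


def training_state_gb(num_params):
    """Weights + gradients + optimizer states for a full model copy, in GB."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


def choose_pattern(num_params, gpu_memory_gb=80.0, reserved_fraction=0.2):
    """Pick DDP if a full copy of the training state fits on one GPU, else FSDP.

    reserved_fraction is a hypothetical share of GPU memory kept free for
    activations and temporary buffers.
    """
    budget_gb = gpu_memory_gb * (1 - reserved_fraction)
    return "DDP" if training_state_gb(num_params) <= budget_gb else "FSDP"


print(training_state_gb(1_000_000_000), choose_pattern(1_000_000_000))  # 16.0 DDP
print(training_state_gb(7_500_000_000), choose_pattern(7_500_000_000))  # 120.0 FSDP
```

Under these assumptions, a 7.5-billion-parameter model already needs about 120 GB for its training state, which exceeds an 80 GB GPU before any activations are counted.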
You can use the PyTorch distributed RPC framework to combine distributed data parallelism (DDP) with distributed model parallelism to train an LLM. Read more on this page: Combining DDP with distributed RPC framework. Here is the source code for combining DDP with the Distributed RPC framework for demonstration purposes.
Fully Sharded Data Parallel (FSDP) Pattern
In the FSDP pattern, the model is sharded across multiple GPUs because it is too large to fit on a single GPU, as DDP requires, even after quantization. FSDP is motivated by this paper: ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. ZeRO stands for zero redundancy optimizer. The idea of ZeRO is to reduce DDP's data redundancy by sharding the model, including its gradients, activations, and optimizer states, across the GPUs to achieve zero redundancy in the system.
In FSDP, following ZeRO, sharding can be applied in stages that differ in what is sharded across the GPUs:
- Stage 1 shards only the optimizer states across the GPUs. It reduces the memory footprint by up to 4x.
- Stage 2 shards both the optimizer states and the gradients across the GPUs. It reduces the memory footprint by up to 8x.
- Stage 3 shards everything, including the model parameters (weights), the optimizer states, and the gradients, across the GPUs. It reduces the memory footprint by up to N times, where N is the number of GPUs.
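The reduction factors for the stages above can be checked with a short calculation. This sketch follows the memory accounting in the ZeRO paper, assuming mixed-precision training with Adam (2 bytes per parameter for weights, 2 for gradients, 12 for optimizer states). The function name `zero_memory_gb` and the example of a 7.5-billion-parameter model on 64 GPUs are illustrative assumptions. Activations and temporary buffers are not counted.

```python
def zero_memory_gb(num_params, num_gpus, stage):
    """Per-GPU memory (GB) for weights, gradients, and optimizer states.

    stage 0 is plain DDP (nothing sharded); stages 1-3 follow the ZeRO stages.
    """
    p, g, k = 2, 2, 12  # bytes per parameter: weights, gradients, optimizer states
    if stage == 0:
        per_param = p + g + k
    elif stage == 1:                      # optimizer states sharded
        per_param = p + g + k / num_gpus
    elif stage == 2:                      # optimizer states + gradients sharded
        per_param = p + (g + k) / num_gpus
    elif stage == 3:                      # parameters sharded as well
        per_param = (p + g + k) / num_gpus
    else:
        raise ValueError("stage must be 0, 1, 2, or 3")
    return num_params * per_param / 1e9


num_params, num_gpus = 7_500_000_000, 64  # illustrative model and cluster size

for stage in range(4):
    print(f"stage {stage}: {zero_memory_gb(num_params, num_gpus, stage):.1f} GB per GPU")
# stage 0: 120.0 GB per GPU
# stage 1: 31.4 GB per GPU
# stage 2: 16.6 GB per GPU
# stage 3: 1.9 GB per GPU
```

With 64 GPUs the stage 1 and stage 2 reductions are about 3.8x and 7.2x. They approach 4x and 8x as the number of GPUs grows, while stage 3 scales with the number of GPUs.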
When training with FSDP, the GPU memory footprint on every worker is smaller than when training with DDP. This makes the training of some very large models feasible by allowing larger models or batch sizes to fit on the device. Read more on this PyTorch page: Getting started with Fully Sharded Data Parallel (FSDP). Further details can be found in this paper: PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel.
FSDP requires dynamic reconstruction of the full layer from the sharded data onto each GPU before the forward and backward passes. Here are the key steps:
- Breaking Down the Model: In FSDP, the large neural network model is broken down into smaller units. Each unit contains a portion of the model’s parameters.
- Sharded model parameters: Parameters within each unit are flattened (i.e., arranged in a simple linear structure) and then sharded. This means dividing these parameters across multiple GPUs. Each GPU holds only a portion of the entire model’s parameters, unlike in Distributed Data Parallel (DDP) where each GPU holds a complete copy of the model.
- On-demand parameter request: Before performing a forward pass (a step in the training process where the model makes predictions based on input data), each GPU needs access to the complete set of parameters required for that particular operation. Since each GPU holds only a shard of the parameters, it requests the missing shards from the other GPUs.
- Materializing unsharded parameters locally: The requested shards from other GPUs are then temporarily assembled, or "materialized", on the requesting GPU. This creates a complete, unsharded set of parameters required for the operation, but only for the duration of that specific operation. This process typically occurs on a per-layer basis during model training: for each layer of the neural network, the GPU requests, assembles, and uses the necessary parameters from other GPUs as needed.
- Releasing unsharded parameters after the forward pass: Once a unit has finished its forward computation, each GPU releases (frees) the parameters it gathered from other GPUs. Nothing needs to be sent back, because every GPU still holds its own shard.
- Reverting to sharded state: After the release, each GPU again holds only its original portion of the model's parameters. This frees up GPU memory, which is crucial for efficiently managing resources, especially when dealing with large models.
- Backward pass processing: The backward pass, a phase where the model learns by computing error gradients for its parameters, is then carried out. As in the forward pass, the full parameters of each unit are gathered just before they are needed and released afterwards, so the memory saved by reverting to the sharded state remains available for this resource-intensive process.
- Gradient synchronization across GPUs: Similar to the DDP pattern, after the backward pass, FSDP synchronizes the gradients (the computed adjustments to model parameters) across all GPUs, with each GPU ending up with the combined gradients for the parameters in its own shard. This ensures that all shards of the model are updated consistently based on the data processed by all GPUs.
- Updating Model Parameters Across Shards: Finally, the model parameters across all the shards (which are distributed across different GPUs) are updated. This step involves adjusting the model’s parameters based on the synchronized gradients, ensuring that each shard is up-to-date and consistent with the others.
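The steps above can be sketched as a single-process simulation. This is not the PyTorch FSDP implementation, only an illustration under simplifying assumptions: the whole model is a single unit (a linear model with four weights), the "GPUs" are entries in a Python list, and `all_gather` and `reduce_scatter_mean` are hypothetical stand-ins for the collective operations. In a real run, gathering and releasing happen per unit, in both the forward and the backward pass.

```python
def shard(flat_params, world_size):
    """Split a flat parameter list into equal, contiguous shards (one per GPU)."""
    size = len(flat_params) // world_size
    return [flat_params[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards):
    """Stand-in for all-gather: rebuild the full flat parameter list."""
    return [p for s in shards for p in s]


def local_gradient(params, batch):
    """Gradient of the mean squared error of a linear model on one worker's batch."""
    grad = [0.0] * len(params)
    for x, y in batch:
        err = sum(w * xi for w, xi in zip(params, x)) - y
        for i, xi in enumerate(x):
            grad[i] += 2 * err * xi / len(batch)
    return grad


def reduce_scatter_mean(full_grads, world_size):
    """Stand-in for reduce-scatter: average gradients, give each GPU its slice."""
    n = len(full_grads[0])
    avg = [sum(g[i] for g in full_grads) / world_size for i in range(n)]
    return shard(avg, world_size)


def fsdp_step(shards, batches, lr):
    world_size = len(shards)
    full_grads = []
    for rank in range(world_size):
        full_params = all_gather(shards)   # materialize unsharded parameters
        full_grads.append(local_gradient(full_params, batches[rank]))
        del full_params                    # release; only the local shard remains
    grad_shards = reduce_scatter_mean(full_grads, world_size)
    # Each GPU updates only the parameters in its own shard.
    return [[w - lr * g for w, g in zip(s, gs)]
            for s, gs in zip(shards, grad_shards)]


params = [0.5, -1.0, 2.0, 0.0]             # flattened parameters of one unit
batches = [
    [([1.0, 0.0, 2.0, 1.0], 3.0)],         # batch seen by simulated GPU 0
    [([0.0, 1.0, 1.0, 2.0], 2.0)],         # batch seen by simulated GPU 1
]
shards = shard(params, world_size=2)       # each GPU holds 2 of the 4 parameters
new_shards = fsdp_step(shards, batches, lr=0.1)
print(new_shards)
```

Gathering the updated shards gives the same parameters as a single-process gradient step over both batches, while each simulated GPU only ever stores half of the parameters between operations.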
When to use FSDP vs DDP?
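FSDP can scale model training for both small and large models across different GPU cluster sizes. For smaller LLMs, DDP and FSDP have been found to perform on par with each other. For models whose training data (parameters, optimizer states, activations, and so on) cannot be stored on a single GPU, DDP is not an option and FSDP is the pattern to use.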
I have been recently working in the area of Data analytics including Data Science and Machine Learning / Deep Learning. I am also passionate about different technologies including programming languages such as Java/JEE, Javascript, Python, R, Julia, etc, and technologies such as Blockchain, mobile computing, cloud-native technologies, application security, cloud computing platforms, big data, etc. I would love to connect with you on Linkedin. Check out my latest book titled as First Principles Thinking: Building winning products using first principles thinking.
Latest posts by Ajitesh Kumar