Training large language models (LLMs) like GPT-4 requires distributed computing patterns: the models have billions of parameters and are trained on vast amounts of data, while even the largest single GPUs (such as the NVIDIA A100 with 80 GB of memory) cannot hold everything needed for training. In this blog, we will delve into two of the most important distributed LLM training patterns: distributed data parallel (DDP) and fully sharded data parallel (FSDP). The primary difference between these patterns lies in how the model is split, or sharded, across the GPUs in the system. You can find greater detail in this book: Generative AI on AWS.
In the DDP computing pattern, a complete copy of the model is loaded onto each GPU in the distributed computing environment. This means that if you have multiple GPUs, each one holds a full replica of the model. To make that single copy fit, a technique called quantization is often used. Quantization reduces the precision of model parameters (such as weights and activations) to lower-bit representations, decreasing memory usage and computational requirements while largely maintaining model performance. For example, reducing parameters from 32-bit to 16-bit precision cuts the parameter memory requirement in half, and 8-bit precision cuts it to a quarter. Read this post for the details – LLM GPU memory requirements.
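As a rough illustration of the memory savings, the following sketch (using a small stand-in model, not a real LLM) compares the parameter memory of the same network in 32-bit and 16-bit precision:

```python
import torch
import torch.nn as nn

# Toy model standing in for an LLM; layer sizes are illustrative only.
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))

def param_bytes(m: nn.Module) -> int:
    # Total memory consumed by the parameters, in bytes.
    return sum(p.numel() * p.element_size() for p in m.parameters())

print(f"fp32 parameters: {param_bytes(model) / 1e6:.1f} MB")

# Reduce precision from 32-bit to 16-bit floats (note: .half() converts in place).
model = model.half()
print(f"fp16 parameters: {param_bytes(model) / 1e6:.1f} MB")  # roughly half the fp32 size
```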
The following represents how LLM training happens based on DDP:
When implementing DDP, you must ensure that each GPU can fit not only the LLM parameters and data batches but also all the additional data needed by the training loop, including optimizer states, activations, and temporary function variables. In DDP, each GPU holds a full copy of everything required to perform the forward and backward pass. If all of this cannot fit on a single GPU, another distributed pattern comes into the picture: FSDP, or fully sharded data parallel.
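A minimal DDP training sketch in PyTorch might look like the following; the model, dataset, and hyperparameters are placeholders, and in practice you would launch one such process per GPU (for example with torchrun or torch.multiprocessing.spawn):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def train(rank: int, world_size: int):
    # One process per GPU; NCCL is the usual backend for GPU training.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Placeholder model and data; each GPU holds a full copy of the model.
    model = torch.nn.Linear(1024, 1024).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    dataset = TensorDataset(torch.randn(4096, 1024), torch.randn(4096, 1024))
    sampler = DistributedSampler(dataset)   # each rank sees a distinct slice of the data
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)
    loss_fn = torch.nn.MSELoss()

    for inputs, targets in loader:
        inputs, targets = inputs.to(rank), targets.to(rank)
        optimizer.zero_grad()
        loss = loss_fn(ddp_model(inputs), targets)
        loss.backward()   # gradients are averaged (all-reduced) across GPUs here
        optimizer.step()

    dist.destroy_process_group()
```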
You can also use the PyTorch distributed RPC framework to combine distributed data parallelism (DDP) with distributed model parallelism when training an LLM. Read more on this page – Combining DDP with the distributed RPC framework. Here is the source code for combining DDP with the distributed RPC framework, provided for demonstration purposes.
The FSDP pattern is used when the model is too large to fit on a single GPU (as DDP requires), even after quantization; instead, the model is sharded across multiple GPUs. FSDP is motivated by this paper – ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. ZeRO stands for zero redundancy optimizer. The idea of ZeRO is to eliminate DDP’s data redundancy by sharding the model parameters, along with their gradients and the optimizer states, across the GPUs so that nothing is needlessly replicated in the system.
In FSDP, sharding can be applied in progressively more aggressive stages, following the ZeRO paper: stage 1 shards only the optimizer states, stage 2 additionally shards the gradients, and stage 3 additionally shards the model parameters themselves, giving the largest memory savings. These choices are sketched below.
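As a rough mapping, PyTorch's FSDP exposes these stages through its ShardingStrategy options. A minimal sketch, assuming a torchrun-style launch with one process per GPU on a single node:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# Assumes one process per GPU, launched e.g. via torchrun (single-node).
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# Placeholder model standing in for an LLM.
model = torch.nn.Linear(4096, 4096).cuda()

# NO_SHARD      ~ plain DDP (everything replicated on every GPU)
# SHARD_GRAD_OP ~ ZeRO stage 2 (optimizer states and gradients sharded)
# FULL_SHARD    ~ ZeRO stage 3 (optimizer states, gradients, and parameters sharded)
fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
```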
When training with FSDP, the GPU memory footprint on each worker is smaller than with DDP, which makes training some very large models feasible by allowing larger models or batch sizes to fit on each device. More details are available on this PyTorch page – Getting started with Fully Sharded Data Parallel (FSDP) – and in this paper – PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel.
FSDP requires dynamically reconstructing each full layer from its shards on every GPU just before that layer's forward and backward passes. Here are the key steps:
- Before a layer's forward pass, each GPU all-gathers the parameter shards from its peers to materialize the full layer, then frees the non-local shards once the pass completes.
- Before the backward pass, the full parameters are all-gathered again so gradients can be computed.
- After the backward pass, gradients are reduce-scattered so that each GPU keeps only the gradient shard matching its parameter shard, and each GPU updates only its own shard of the parameters and optimizer state.
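The all-gather and reduce-scatter steps above happen inside FSDP's forward and backward passes, so the training loop itself looks much like the DDP one. A minimal sketch with a placeholder model, assuming a torchrun-style launch:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes one process per GPU, launched e.g. via torchrun (single-node).
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# Placeholder model; a real LLM would typically use an auto-wrap policy so
# each transformer block becomes its own FSDP unit.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 4096),
).cuda()

fsdp_model = FSDP(model)  # default strategy fully shards parameters, gradients, and optimizer states
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

for step in range(10):
    batch = torch.randn(8, 4096, device="cuda")    # stand-in for a real data batch
    target = torch.randn(8, 4096, device="cuda")

    loss = torch.nn.functional.mse_loss(fsdp_model(batch), target)  # shards all-gathered inside forward
    loss.backward()      # parameters re-gathered for backward, gradients reduce-scattered to shards
    optimizer.step()     # each rank updates only its own shard of parameters and optimizer state
    optimizer.zero_grad()

dist.destroy_process_group()
```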
When to use FSDP vs DDP?
FSDP can scale model training for both small and large models across different GPU cluster sizes. For smaller LLMs that fit comfortably on a single GPU, DDP and FSDP are generally found to perform on par with each other.