What is the optimal framework and configuration for hosting large language models (LLMs) for text-generating generative AI applications? Despite the abundance of options for serving LLMs, this is a hard question to answer due to the size of the models, varying model architectures, performance requirements of applications, and more.
The Amazon SageMaker Large Model Inference (LMI) container makes it straightforward to serve LLMs by bringing together a host of different frameworks and techniques that optimize the deployment of LLMs. The LMI container has a powerful serving stack called DJLServing that is agnostic to the underlying LLM. It provides system-level configuration parameters that can be tuned to extract the best performance from the hosting infrastructure for a given LLM. It also supports recent optimizations like continuous batching, also known as iterative batching or rolling batching, which provides significant improvements in throughput.
In an earlier post, we showed how you can use the LMI container to deploy the Falcon family of models on SageMaker. In this post, we demonstrate how to improve the throughput and latency of serving Falcon-40B with techniques like continuous batching. We also provide an intuitive understanding of configuration parameters provided by the SageMaker LMI container that can help you find the best configuration for your real-world application.
Fundamentals of text-generative inference for LLMs
Let’s first look at a few fundamentals on how to perform inference for LLMs for text generation.
Forward pass, activations, and the KV cache
An input sequence of tokens is run in a forward pass through all the layers of the LLM (such as Falcon) to generate the next token. A forward pass refers to the process of input data being passed through a neural network to produce an output. In the case of text generation, the forward pass involves feeding an initial seed or context into the language model and generating the next character or token in the sequence. To generate a sequence of text, the process is repeated iteratively, once for each position in the output sequence. At each iteration, the model generates the next character or token, which becomes part of the generated text, and this process continues until the desired length of text is generated.
Text generation with language models like Falcon or GPT is autoregressive. This means that the model generates one token at a time while conditioning on the previously generated tokens. In other words, at each iteration, the model takes the previously generated text as input and predicts the next token based on that context. As mentioned in vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention, in this autoregressive decoding process, all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to generate the next tokens. These cached key and value tensors are often referred to as the KV cache.
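The effect of the KV cache can be illustrated with a toy example. The following is a minimal sketch, not how Falcon or any real LLM is implemented: the `project` function is a hypothetical stand-in for learned projections, the "model" is a single attention layer, and the token IDs are arbitrary. It only shows that caching keys and values yields the same tokens while computing far fewer key and value tensors.

```python
import math


def project(token_id, salt):
    """Hypothetical stand-in for a learned projection: token ID -> 2-d vector."""
    return [math.sin(token_id + salt), math.cos(token_id * salt + 1.0)]


class ToyAttentionLayer:
    """A single attention layer that keeps keys and values in a cache."""

    def __init__(self):
        self.keys = []    # cached key vectors, one per token seen so far
        self.values = []  # cached value vectors, one per token seen so far

    def forward(self, new_token_ids):
        """Add K/V for the new tokens, then attend from the last token."""
        for token_id in new_token_ids:
            self.keys.append(project(token_id, 1.0))
            self.values.append(project(token_id, 2.0))
        query = project(new_token_ids[-1], 3.0)
        scores = [sum(q * k for q, k in zip(query, key)) for key in self.keys]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]  # stable softmax
        total = sum(weights)
        return [
            sum(w / total * value[i] for w, value in zip(weights, self.values))
            for i in range(2)
        ]


def pick_next_token(attention_output, vocab_size=50):
    """Hypothetical deterministic rule that turns the output into a token ID."""
    return int(abs(attention_output[0] + attention_output[1]) * 1000) % vocab_size


def generate(prompt_ids, max_new_tokens, use_cache=True):
    """Return the full token list and the number of K/V pairs computed."""
    tokens = list(prompt_ids)
    layer = ToyAttentionLayer()
    kv_computations = 0
    for step in range(max_new_tokens):
        if use_cache:
            # First pass covers the whole prompt; later passes only the newest token.
            new_ids = tokens if step == 0 else tokens[-1:]
        else:
            # Without a cache, every step recomputes K/V for all tokens so far.
            layer = ToyAttentionLayer()
            new_ids = tokens
        output = layer.forward(new_ids)
        kv_computations += len(new_ids)
        tokens.append(pick_next_token(output))
    return tokens, kv_computations


cached_tokens, cached_cost = generate([3, 14, 15, 9], 5, use_cache=True)
plain_tokens, plain_cost = generate([3, 14, 15, 9], 5, use_cache=False)
print(cached_tokens == plain_tokens)  # True: same text either way
print(cached_cost, plain_cost)        # 8 30
```

With a 4-token prompt and 5 generated tokens, the cached run computes 8 key/value pairs versus 30 without the cache. The gap widens as sequences get longer, which is why the cache is kept in GPU memory despite its size.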
Prefill and decode phases
In an autoregressive decoding process, like the one used in text generation with language models such as Falcon, there are typically two main phases: the prefill phase and the decode phase. These phases are crucial for generating coherent and contextually relevant text.
The prefill phase includes the following:
- Initial context – The prefill phase begins with the prompt provided by the user. This initial context can be a sentence, a phrase, or even just a single word. It sets the starting point for text generation and provides context for what comes next.
- Model conditioning – The prompt is used to condition the language model. Because all of the prompt tokens are known up front, the model processes them together in a single forward pass.
- KV cache population – During that forward pass, the attention key and value tensors of every prompt token are computed and stored in the KV cache.
- First token generation – The output of the forward pass is used to predict the first new token. This ends the prefill phase.
The decode phase includes the following:
- Continuation from the last token – The decode phase starts from the token generated at the end of the prefill phase. Only this newest token is run through the model, and the keys and values of all earlier tokens are read from the KV cache.
- Iterative generation – The model generates one token per forward pass, conditioning on the preceding tokens in the sequence. Each new token is appended to the context, and its key and value tensors are appended to the KV cache.
- Stopping condition – The decode phase continues until a stopping condition is met, such as reaching a maximum length for the generated text or encountering an end-of-text token. When this condition is met, the generation process stops.
The combination of the prefill and decode phases allows autoregressive models to generate text that builds on an initial context and remains coherent and contextually consistent.
Refer to A Distributed Serving System for Transformer-Based Generative Models for a detailed explanation of the process.
Optimizing throughput using dynamic batching
So far, we’ve only talked about a single input. In practice, we expect to deal with multiple requests that arrive at random from application clients, either concurrently or in a staggered fashion. Traditionally, basic batching is used to increase throughput and the utilization of the GPU’s computing resources. Batching combines the numerical representations of more than one request into a batch and runs the autoregressive forward passes for them in parallel. This batching is done on the serving side.
SageMaker LMI’s DJLServing server can be configured to batch together multiple requests to process them in parallel by setting the following parameters in serving.properties:
- max_batch_delay = 100 – The maximum delay for batch aggregation in milliseconds. The default value is 100 milliseconds.
- batch_size = 32 – The dynamic batch size. The default is 1.
With these settings, DJLServing queues up requests for up to 100 milliseconds at a time. The batch is scheduled to run on the backend for inference either when that delay expires or as soon as the number of queued requests reaches the specified batch_size, whichever comes first. This is known as dynamic batching. It’s dynamic because the batch size may change across batches depending on how many requests were added in that time duration.
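The batching rule described above can be sketched as a small simulation. This is an illustration only, not DJLServing’s actual scheduler: the function name `form_dynamic_batches` is hypothetical, and it assumes the delay window starts when the first request of a batch arrives.

```python
# The two settings from serving.properties, as they would appear in the file.
SERVING_PROPERTIES = """\
max_batch_delay=100
batch_size=32
"""


def form_dynamic_batches(arrival_times_ms, batch_size, max_batch_delay_ms):
    """Group request arrival times (in ms) into dynamic batches.

    A batch is dispatched when it holds `batch_size` requests, or when
    `max_batch_delay_ms` has passed since its first request arrived
    (an assumption made for this sketch), whichever comes first.
    """
    batches = []
    current = []
    for arrival in sorted(arrival_times_ms):
        # The delay for the open batch expired before this request arrived.
        if current and arrival - current[0] >= max_batch_delay_ms:
            batches.append(current)
            current = []
        current.append(arrival)
        # The batch is full, so it is dispatched immediately.
        if len(current) == batch_size:
            batches.append(current)
            current = []
    if current:
        batches.append(current)  # leftover requests go out when the delay expires
    return batches


# A small batch_size keeps the example readable.
batches = form_dynamic_batches(
    [0, 10, 20, 30, 150, 160, 400], batch_size=3, max_batch_delay_ms=100
)
print(batches)                    # [[0, 10, 20], [30], [150, 160], [400]]
print([len(b) for b in batches])  # [3, 1, 2, 1]
```

The batch sizes vary from 1 to 3 even though batch_size is fixed at 3, which is the sense in which the batching is dynamic.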
However, because requests might have different characteristics (for example, some requests might have 20 tokens of input and 500 tokens of output, whereas others might be reversed, with 500 tokens of input but only 20 of output), some requests might complete processing faster than others in the same batch. This could result in underutilization of the GPU while waiting for all in-flight requests in the batch to complete their decode stage, even if there are additional requests waiting to be processed in the queue.
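A rough back-of-the-envelope calculation shows how large this effect can be. The sketch below is a simplification: it assumes one decode iteration per output token, ignores the prefill cost, and treats every slot in the batch as equally expensive. The function name `static_batch_utilization` is hypothetical.

```python
def static_batch_utilization(output_lengths):
    """Fraction of decode slots doing useful work in a batch that runs
    until its longest request has finished."""
    iterations = max(output_lengths)           # batch runs as long as its longest request
    useful_slots = sum(output_lengths)         # slots that actually produce a token
    total_slots = iterations * len(output_lengths)
    return useful_slots / total_slots


# One request generates 500 tokens, three others only 20 each.
utilization = static_batch_utilization([500, 20, 20, 20])
print(f"{utilization:.0%}")  # 28%
```

Under these assumptions, three of the four slots sit idle for 480 of the 500 iterations, so only 28% of the available decode slots produce tokens.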
The following diagram illustrates this process:
Optimizing throughput using continuous batching
With continuous batching, also known as iterative or rolling batching, we take advantage of the differences between the prefill and decode stages. To activate continuous batching, DJLServing provides the following additional configurations in serving.properties:
- engine=MPI – We encourage you to use the MPI engine for continuous batching.
- option.rolling_batch=auto or lmi-dist – We recommend using auto because it automatically picks the most appropriate rolling batch algorithm, and it will also pick up other optimizations as they are added in the future.
- option.max_rolling_batch_size=32 – This limits the number of concurrent requests. The default is 32.
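The three settings above can be collected into a serving.properties file. The following sketch generates such a file from Python. The helper names `render_serving_properties` and `write_serving_properties` are hypothetical, and the result is only a fragment: a real deployment will usually need further entries, such as the model location, that this section doesn’t cover.

```python
from pathlib import Path

# Only the continuous batching settings listed above; values are taken from the text.
CONTINUOUS_BATCHING_OPTIONS = {
    "engine": "MPI",
    "option.rolling_batch": "auto",
    "option.max_rolling_batch_size": "32",
}


def render_serving_properties(options):
    """Render a dict of settings as key=value lines."""
    return "".join(f"{key}={value}\n" for key, value in options.items())


def write_serving_properties(directory, options):
    """Write serving.properties into `directory` and return its path."""
    path = Path(directory) / "serving.properties"
    path.write_text(render_serving_properties(options))
    return path


print(render_serving_properties(CONTINUOUS_BATCHING_OPTIONS), end="")
# engine=MPI
# option.rolling_batch=auto
# option.max_rolling_batch_size=32
```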
With continuous batching, the serving stack (DJLServing) doesn’t wait for all in-flight requests in a batch to complete their decode stage. Rather, at logical breaks (at the end of one iteration in the decode stage), it pulls in additional requests that are waiting in…