Grouped Query Attention
This May 2023 paper presents an approach to improving the inference speed of large language models (LLMs) with autoregressive decoders, specifically targeting the memory bandwidth overhead of loading keys and values at every decoding step in Transformer models.
It introduces two key contributions aimed at balancing the trade-off between model quality and inference speed: the uptraining of existing multi-head language model checkpoints to use Multi-Query Attention (MQA) and the introduction of Grouped-Query Attention (GQA).
MQA is a technique that reduces memory bandwidth overhead by using multiple query heads but a single key and value head.
While it significantly speeds up decoder inference, MQA can lead to quality degradation and training instability.
Additionally, having to train and maintain separate models optimised for either quality or inference speed is often impractical. Despite these drawbacks, MQA sharply reduces the memory bandwidth needed to load keys and values at each decoding step, which is what makes it attractive for fast inference.
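To make the mechanism concrete, the sketch below shows a minimal, PyTorch-style MQA layer in which many query heads attend over a single shared key/value head. The tensor names, shapes, and the absence of an output projection are illustrative assumptions, not the paper's implementation.

```python
import torch
import torch.nn.functional as F

def mqa_attention(x, w_q, w_k, w_v, num_heads):
    """x: (batch, seq, d_model); w_q: (d_model, d_model);
    w_k, w_v: (d_model, head_dim) -- one shared key/value head."""
    batch, seq, d_model = x.shape
    head_dim = d_model // num_heads

    # Many query heads, but only one key head and one value head.
    q = (x @ w_q).view(batch, seq, num_heads, head_dim).transpose(1, 2)  # (b, h, s, d)
    k = (x @ w_k).unsqueeze(1)                                           # (b, 1, s, d)
    v = (x @ w_v).unsqueeze(1)                                           # (b, 1, s, d)

    # The single K/V head broadcasts across all query heads, so only one
    # key/value tensor needs to be loaded from memory per decoding step.
    scores = q @ k.transpose(-2, -1) / head_dim ** 0.5                   # (b, h, s, s)
    attn = F.softmax(scores, dim=-1)
    out = (attn @ v).transpose(1, 2).reshape(batch, seq, d_model)
    return out
```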
The first contribution of this work is a recipe for uptraining existing multi-head attention (MHA) model checkpoints to use MQA, requiring only 5% of the original pre-training compute.
The process converts an MHA checkpoint to MQA by mean-pooling the per-head key and value projection matrices into single projection matrices, then continuing pre-training for a small number of additional steps.
This method offers a cost-effective way to achieve models that are fast like MQA models but maintain the high quality of MHA checkpoints.
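A rough sketch of that conversion step is below. It assumes the per-head key and value projections are stored as (num_heads, d_model, head_dim) tensors; the function and variable names are hypothetical, and only the mean-pooling idea comes from the paper.

```python
import torch

def pool_kv_heads(w_k_heads, w_v_heads, num_groups=1):
    """Mean-pool per-head K/V projections into `num_groups` heads.
    num_groups=1 gives MQA; 1 < num_groups < num_heads gives GQA."""
    num_heads, d_model, head_dim = w_k_heads.shape
    assert num_heads % num_groups == 0
    group_size = num_heads // num_groups

    # Average each group of key/value heads into a single shared head.
    w_k_grouped = w_k_heads.view(num_groups, group_size, d_model, head_dim).mean(dim=1)
    w_v_grouped = w_v_heads.view(num_groups, group_size, d_model, head_dim).mean(dim=1)
    return w_k_grouped, w_v_grouped  # each: (num_groups, d_model, head_dim)
```

After this pooling step, the converted checkpoint is uptrained with a small fraction of the original pre-training compute so the model adapts to the shared key/value heads.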
The second contribution is GQA, which generalises MQA by using an intermediate number of key-value heads—more than one but fewer than the number of query heads.
GQA divides query heads into groups, with each group sharing a single key head and value head. This approach allows for a model that achieves close to the quality of MHA while nearing the inference speed of MQA.
GQA represents a favourable compromise by interpolating between the two extremes of MHA and MQA, offering a scalable solution that adjusts the balance between memory bandwidth and model quality.
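The sketch below illustrates this interpolation, assuming queries, keys, and values have already been projected and split into heads; the helper is illustrative rather than the paper's code. Consecutive query heads share one key/value head per group, so the number of groups is the knob that trades memory bandwidth against quality.

```python
import torch
import torch.nn.functional as F

def gqa_attention(q, k, v, num_groups):
    """q: (batch, num_q_heads, seq, head_dim); k, v: (batch, num_groups, seq, head_dim).
    Each group of num_q_heads // num_groups query heads shares one K/V head."""
    batch, num_q_heads, seq, head_dim = q.shape
    group_size = num_q_heads // num_groups

    # Repeat each K/V head so it lines up with the query heads in its group.
    k = k.repeat_interleave(group_size, dim=1)   # (batch, num_q_heads, seq, head_dim)
    v = v.repeat_interleave(group_size, dim=1)

    scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
    attn = F.softmax(scores, dim=-1)
    return attn @ v
```

Setting num_groups equal to the number of query heads recovers standard MHA, while num_groups=1 recovers MQA; intermediate values give the GQA configurations studied in the paper.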
The methodology for implementing these contributions involves two steps: converting existing MHA models to either MQA or GQA configurations and then uptraining these models with a fraction of the original training compute.
The paper demonstrates that models uptrained to use GQA achieve a quality close to that of MHA models but with an inference speed comparable to MQA models. This is particularly significant for larger models, where memory bandwidth overhead becomes a more pressing issue.
This work has had a significant influence on LLM development.
By providing a method to uptrain existing models to use MQA or GQA, the authors offered a practical pathway to more efficient inference without sacrificing quality.
GQA, in particular, suits larger models, where loading the key-value cache accounts for a growing share of the memory traffic at each decoding step.
In conclusion, this paper introduced a novel approach to improving the inference efficiency of LLMs through MQA and GQA.
It provided a practical solution to the challenges of memory bandwidth overhead in autoregressive decoder inference, proposing a method that balances the need for high-quality model outputs with the demand for faster inference speeds.