Grouped query attention (GQA) is a method to increase the efficiency of the attention mechanism in transformer models, often used to enable faster inference from large language models (LLMs).
Ainslie et al. conceived grouped query attention as an optimization of multi-head attention (MHA), the innovative self-attention algorithm introduced in the seminal 2017 “Attention is All You Need” paper that established transformer neural networks. More specifically, it was proposed as a generalization and more restrained application of multi-query attention (MQA), an earlier optimization of MHA.
Though standard multi-head attention catalyzed an evolutionary leap forward in machine learning, natural language processing (NLP) and generative AI, it’s extremely demanding on computational resources and memory bandwidth. As LLMs grew larger and more sophisticated, these memory usage requirements became a bottleneck on progress, especially for the autoregressive decoder-only LLMs used for text generation, summarization and other generative AI tasks.
Subsequent research focused on techniques to enhance or streamline multi-head attention. Some, such as flash attention and ring attention, improve the ways that the GPUs used to train and run models handle calculations and memory storage. Others, such as GQA and MQA, explored changes to the way that transformer architectures process tokens.
Grouped query attention aims to balance the tradeoffs between standard multi-head attention and multi-query attention. The former maximizes accuracy at the cost of increased memory bandwidth overhead and decreased speed. The latter maximizes speed and efficiency at the expense of accuracy.
To understand how grouped query attention optimizes transformer models, it’s important to first understand how multi-head attention works in general. Both GQA and MQA simply refine, rather than replace, the core methodology of MHA.
The driving force behind LLMs and other models that use the transformer architecture is self-attention, a mathematical framework for understanding the relationships between each of the different tokens in a sequence. Self-attention allows an LLM to interpret text data through not only static baseline definitions, but also the context provided by other words and phrases.
In autoregressive LLMs used for text generation, the attention mechanism helps the model predict the next token in a sequence by determining which previous tokens are most worth “paying attention to” at that moment. Information from tokens it deems most relevant is given greater attention weights, while information from tokens deemed irrelevant is given attention weights approaching 0.
The multi-head attention mechanism that animates transformer models generates rich contextual information by splitting each attention layer into multiple attention heads and calculating self-attention many times in parallel.
The authors of “Attention is All You Need” articulated their attention mechanism using the terminology of a relational database: queries, keys and values. Relational databases are designed to simplify the storage and retrieval of relevant data: they assign a unique identifier (“key”) to each piece of data, and each key is associated with a corresponding value. The goal of a relational database is to match each query to the appropriate key.
For each token in a sequence, multi-head attention requires the creation of 3 vectors: a query vector (Q), a key vector (K) and a value vector (V).
The mathematical interactions between these 3 vectors, mediated by the attention mechanism, are how a model adjusts its context-specific understanding of each token.
To generate each of these 3 vectors for a given token, the model starts with that token’s original vector embedding: a numerical encoding in which each dimension of the vector corresponds with some abstract element of the token’s semantic meaning. The number of dimensions in these vectors is a predetermined hyperparameter.
The Q, K and V vectors for each token are generated by passing the original token embedding through a linear layer that precedes the first attention layer. This linear layer is partitioned into 3 unique matrices of model weights: WQ, WK and WV. The specific weight values therein are learned through self-supervised pretraining on a massive dataset of text examples.
Multiplying the token’s original vector embedding by WQ, WK and WV yields its corresponding query vector, key vector and value vector, respectively. The number of dimensions d each vector contains is determined by the size of each weight matrix. Q and K will have the same number of dimensions, dk.
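As a rough sketch of this projection step (the dimension sizes and randomly initialized weight matrices below are hypothetical stand-ins for learned parameters), the Q, K and V vectors for one token can be computed as follows:

```python
import numpy as np

d_model, d_k, d_v = 512, 64, 64           # hypothetical embedding and projection sizes

# Stand-ins for the learned weight matrices WQ, WK and WV
W_Q = np.random.randn(d_model, d_k) * 0.02
W_K = np.random.randn(d_model, d_k) * 0.02
W_V = np.random.randn(d_model, d_v) * 0.02

x = np.random.randn(d_model)              # a token's original vector embedding

# Multiplying the embedding by each weight matrix yields the 3 vectors
Q = x @ W_Q                               # query vector, dk dimensions
K = x @ W_K                               # key vector, dk dimensions
V = x @ W_V                               # value vector, dv dimensions
```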
These 3 vectors are then passed to the attention layer.
In the attention layer, the Q, K and V vectors are used to calculate an alignment score between each token at each position in a sequence. Those alignment scores are then normalized into attention weights using a softmax function.
For each token x in a sequence, alignment scores are calculated by computing the dot product of that token’s query vector Qx with the key vector K of each of the other tokens: in other words, by multiplying them together. If a meaningful relationship between 2 tokens is reflected in similarities between their respective vectors, multiplying them together will yield a large value. If the 2 vectors aren’t aligned, multiplying them together will yield a small or negative value. Most transformer models use a variant called scaled dot product attention, in which QK is scaled by dividing it by √dk, the square root of the key vectors’ dimension, to improve training stability.
These query-key alignment scores are then passed into a softmax function. Softmax normalizes all inputs to a value between 0 and 1 such that they all add up to 1. The outputs of the softmax function are the attention weights, each representing the share (out of 1) of token x’s attention to be paid to each of the other tokens. If a token’s attention weight is close to 0, it will be ignored. An attention weight of 1 would mean that a token receives x’s entire attention and all others will be ignored.
Finally, the value vector for each token is multiplied by its attention weight. These attention-weighted contributions from each previous token are averaged together and added to the original vector embedding for token x. With this, token x’s embedding is now updated to reflect the context provided by the other tokens in the sequence that are relevant to it.
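To make these steps concrete, here is a minimal NumPy sketch of scaled dot product attention for one short sequence. The sequence length and dimension sizes are illustrative, and the matrices Q, K and V simply stack the per-token query, key and value vectors described above:

```python
import numpy as np

def softmax(scores):
    # Subtract the row-wise max before exponentiating for numerical stability
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    # Alignment scores: each query dotted with every key, scaled by the square root of dk
    scores = (Q @ K.T) / np.sqrt(d_k)            # shape: (seq_len, seq_len)
    weights = softmax(scores)                    # attention weights; each row sums to 1
    return weights @ V                           # attention-weighted sum of value vectors

seq_len, d_k, d_v = 6, 64, 64                    # hypothetical sizes
Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)
context = scaled_dot_product_attention(Q, K, V)  # one context-weighted row per token
```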
The updated vector embedding is then sent to another linear layer, with its own weight matrix WZ, where the context-updated vector is projected back to a consistent number of dimensions and then sent to the next attention layer. Each progressive attention layer captures greater contextual nuance.
Using the averages of attention-weighted contributions from other tokens instead of accounting for each piece of attention-weighted context individually is mathematically efficient, but it results in a loss of detail.
To compensate, transformer networks split the original input token’s embedding into h evenly sized pieces. They likewise split WQ, WK and WV into h subsets called query heads, key heads and value heads, respectively. Each query head, key head and value head receives a piece of the original token embedding. The vectors produced by each of these parallel triplets of query, key and value heads are fed into a corresponding attention head. Eventually, the outputs of these h parallel circuits are concatenated back together to update the full token embedding.
In training, each circuit learns distinct weights that capture a separate aspect of semantic meaning. This, in turn, helps the model process the different ways that a word’s implications can be influenced by the context of other words around it.
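The sketch below illustrates this multi-head arrangement with NumPy. For simplicity it projects the full embeddings and then slices the result into h heads, which is mathematically equivalent to maintaining h separate per-head weight matrices; the shapes are again hypothetical:

```python
import numpy as np

def softmax(scores):
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_Z, h):
    seq_len, d_model = X.shape
    d_head = d_model // h                        # each head works on an evenly sized slice
    # Project, then split the projections across h query, key and value heads
    Q = (X @ W_Q).reshape(seq_len, h, d_head)
    K = (X @ W_K).reshape(seq_len, h, d_head)
    V = (X @ W_V).reshape(seq_len, h, d_head)
    heads = []
    for i in range(h):                           # one parallel attention circuit per head
        scores = (Q[:, i] @ K[:, i].T) / np.sqrt(d_head)
        heads.append(softmax(scores) @ V[:, i])
    # Concatenate the h head outputs and project back to d_model with WZ
    return np.concatenate(heads, axis=-1) @ W_Z

seq_len, d_model, h = 6, 512, 8                  # hypothetical sizes
X = np.random.randn(seq_len, d_model)            # embeddings for one sequence
W_Q, W_K, W_V, W_Z = (np.random.randn(d_model, d_model) * 0.02 for _ in range(4))
out = multi_head_attention(X, W_Q, W_K, W_V, W_Z, h)
```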
The downside of standard multi-head attention is not so much the presence of some crucial flaw, but rather the lack of any optimization. MHA was the first algorithm of its kind and represents the most complex execution of its general mechanism for attention computation.
Most of MHA’s inefficiency stems from the abundance of calculations and model parameters. In standard MHA, each query head, key head and value head in each attention block has its own matrix of weights. So, for instance, a model with 8 attention heads in each attention layer—far fewer than most modern LLMs—would require 24 unique weight matrices for the layer's Q, K and V heads alone. This entails a huge number of intermediate calculations at each layer.
One consequence of this configuration is that it’s computationally expensive. Compute requirements for MHA scale quadratically with respect to sequence length: doubling the number of tokens in an input sequence quadruples the amount of computation required. This puts hard practical limits on the size of context windows.
MHA also puts a major strain on system memory. GPUs don’t have much on-board memory to store the outputs of the massive quantity of intermediate calculations that must be recalled at each subsequent processing step. These intermediate results are instead stored in high-bandwidth memory (HBM), which isn’t located on the GPU chip itself. This entails a small amount of latency each time keys and values must be read from memory. As transformer models began to scale to many billions of parameters, the time and compute required to train and run inference became a bottleneck on model performance.
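A back-of-the-envelope estimate shows why this becomes a bottleneck. Under standard MHA, the keys and values for every layer, head and previous token must be kept available during autoregressive decoding. The model configuration below is hypothetical, chosen only to be roughly in the range of a mid-sized decoder-only LLM:

```python
# Hypothetical decoder-only model configuration
n_layers  = 32
n_heads   = 32
d_head    = 128
bytes_per = 2          # 16-bit precision

# Each layer caches one key vector and one value vector per head, per token
kv_bytes_per_token = n_layers * n_heads * d_head * 2 * bytes_per
print(f"{kv_bytes_per_token / 1024:.0f} KiB of cached keys and values per token")   # 512 KiB

# Across a 4,096-token context, this storage alone occupies roughly:
print(f"{kv_bytes_per_token * 4096 / 1024**3:.1f} GiB")                             # 2.0 GiB
```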
Further progress required methods to reduce the number of computational steps without reducing the capacity of transformers to learn and reproduce intricately complex linguistic patterns. It was in this context that MQA, and subsequently GQA, were introduced.
Multi-query attention (MQA) is a more computationally efficient attention mechanism that simplifies multi-head attention to reduce memory usage and intermediate calculations. Instead of training a unique key head and value head for each attention head, MQA uses a single key head and single value head at each layer. Therefore, key vectors and value vectors are calculated only once; this single set of key and value vectors is then shared across all h attention heads.
This simplification greatly reduces the number of linear projections that the model must calculate and store in high-bandwidth memory. According to the 2019 paper that introduced MQA, MQA allows a 10–100 times smaller key-value pair storage (or KV cache) and 12 times faster decoder inference. MQA’s reduced memory usage also significantly speeds up training by enabling a larger batch size.
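The sketch below, which reuses the illustrative shapes from the earlier examples, shows the core change MQA makes: the query heads are unchanged, but every one of them attends over the same single key head and value head.

```python
import numpy as np

def softmax(scores):
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)

def multi_query_attention(X, W_Q, W_K, W_V, W_Z, h):
    seq_len, d_model = X.shape
    d_head = d_model // h
    Q = (X @ W_Q).reshape(seq_len, h, d_head)    # h query heads, exactly as in MHA
    K = X @ W_K                                   # a single shared key head...
    V = X @ W_V                                   # ...and a single shared value head
    heads = []
    for i in range(h):                            # every query head reuses the same K and V
        scores = (Q[:, i] @ K.T) / np.sqrt(d_head)
        heads.append(softmax(scores) @ V)
    return np.concatenate(heads, axis=-1) @ W_Z

seq_len, d_model, h = 6, 512, 8
d_head = d_model // h
X = np.random.randn(seq_len, d_model)
W_Q = np.random.randn(d_model, d_model) * 0.02
W_K = np.random.randn(d_model, d_head) * 0.02     # h times smaller than its MHA counterpart
W_V = np.random.randn(d_model, d_head) * 0.02
W_Z = np.random.randn(d_model, d_model) * 0.02
out = multi_query_attention(X, W_Q, W_K, W_V, W_Z, h)
```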
Despite its benefits, MQA comes with unavoidable downsides: collapsing all of the key and value heads into a single pair reduces the model’s representational capacity, which can degrade output quality and make training less stable.
Grouped query attention is a more general, flexible formulation of multi-query attention that partitions query heads into multiple groups that each share a set of keys and values, rather than sharing one set of keys and values across all query heads.
Following the publication of “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” in May 2023, many LLMs quickly adopted GQA. For instance, Meta first adopted GQA for its Llama 2 models in July 2023 and retained GQA in the Llama 3 models released in 2024. Mistral AI used GQA in the Mistral 7B model that it released in September 2023. Likewise, IBM’s Granite 3.0 models employ GQA for fast inference.
In theory, GQA can be thought of as spanning the spectrum between standard MHA and full MQA. GQA with as many key-value head groups as attention heads is equivalent to standard MHA; GQA with a single group is equivalent to MQA.
In practice, GQA almost always implies some intermediate approach, in which the number of groups is itself an important hyperparameter.
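A minimal NumPy sketch of the idea, again with illustrative shapes: the h query heads are divided into g groups, and each group shares one key head and one value head. Setting g equal to h recovers standard MHA, while g = 1 recovers MQA.

```python
import numpy as np

def softmax(scores):
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)

def grouped_query_attention(X, W_Q, W_K, W_V, W_Z, h, g):
    seq_len, d_model = X.shape
    d_head = d_model // h
    Q = (X @ W_Q).reshape(seq_len, h, d_head)    # h query heads
    K = (X @ W_K).reshape(seq_len, g, d_head)    # g key heads (g <= h)
    V = (X @ W_V).reshape(seq_len, g, d_head)    # g value heads
    heads_per_group = h // g
    heads = []
    for i in range(h):
        grp = i // heads_per_group               # query heads in the same group share K and V
        scores = (Q[:, i] @ K[:, grp].T) / np.sqrt(d_head)
        heads.append(softmax(scores) @ V[:, grp])
    return np.concatenate(heads, axis=-1) @ W_Z

seq_len, d_model, h, g = 6, 512, 8, 2            # 8 query heads split into 2 key-value groups
d_head = d_model // h
X = np.random.randn(seq_len, d_model)
W_Q = np.random.randn(d_model, d_model) * 0.02
W_K = np.random.randn(d_model, g * d_head) * 0.02
W_V = np.random.randn(d_model, g * d_head) * 0.02
W_Z = np.random.randn(d_model, d_model) * 0.02
out = grouped_query_attention(X, W_Q, W_K, W_V, W_Z, h, g)
```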
Grouped query attention offers several advantages that have led to its relatively widespread adoption for leading LLMs.