Speeding up Attention Layers

Mario Parreño · #transformer #attention #optimization #cache

This post delves into the fascinating world of attention layers, exploring their evolution from the base multi-head attention to more recent variants like multi-query and grouped-query attention. We’ll examine why these newer architectures appeared and how they address the limitations of their predecessors.

As a preview, the inference complexities, spatial and temporal, are at the core of these improvements, which focus on reducing the cache needed without losing model performance.

Multi-Head Attention

Let’s recall the Attention mechanism, the cornerstone of the Transformer architecture. The input sequence is projected into three different spaces: query, key, and value. Attention scores are computed by taking the dot product of the query and key vectors, then scaled, masked, and softmaxed to obtain the attention weights. Finally, the output is the product of the attention weights and the value vectors, i.e. a weighted sum of the value vectors. The idea of the Attention mechanism is to model the relationships between the tokens in the input sequence.

The Multi-Head Attention extends this idea by computing multiple attention heads in parallel, splitting the query, key, and value vectors into n_heads different heads. If you notice, in a simple attention mechanism, when calculating the product between the attention weights and the value matrix, we are modeling relationships between tokens based on a single set of attention patterns. However, in Multi-Head Attention, we create multiple sets of queries, keys, and values through linear projections of the input. This allows us to apply the attention mechanism separately for each head, enabling the model to capture different aspects of the relationships between tokens simultaneously. Each head can focus on different features or patterns in the data, providing multiple perspectives on the input. The outputs from all heads are then combined, allowing the model to integrate information from these various viewpoints. This approach enhances the model’s ability to capture complex, multi-faceted relationships within the data.

Next is the Multi-Head Attention mechanism. The product head_size × n_heads is usually named the model dimension, hidden size, or embedding size. head_size is the size of the hidden layer in each head, and n_heads is the number of heads.

Multi-Head Attention mechanism.
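To make the shapes concrete, here is a minimal PyTorch sketch of Multi-Head Attention (illustrative only: simplified projections, a causal mask, no dropout; the function and parameter names are my own), showing how the input is split into n_heads sets of queries, keys, and values and how the heads are merged back into the model dimension.

```python
import torch
import torch.nn.functional as F

def multi_head_attention(x, w_q, w_k, w_v, w_o, n_heads):
    """Minimal Multi-Head Attention sketch (no dropout, no padding mask).

    x:   (batch, seq, d_model) input embeddings
    w_*: (d_model, d_model) projection weights, d_model = n_heads * head_size
    """
    b, s, d_model = x.shape
    head_size = d_model // n_heads

    # Project and split into heads: (batch, n_heads, seq, head_size)
    q = (x @ w_q).view(b, s, n_heads, head_size).transpose(1, 2)
    k = (x @ w_k).view(b, s, n_heads, head_size).transpose(1, 2)
    v = (x @ w_v).view(b, s, n_heads, head_size).transpose(1, 2)

    # Scaled dot-product attention with a causal mask
    scores = q @ k.transpose(-2, -1) / head_size ** 0.5   # (b, n_heads, s, s)
    causal = torch.tril(torch.ones(s, s)).bool()
    scores = scores.masked_fill(~causal, float("-inf"))
    weights = F.softmax(scores, dim=-1)

    # Weighted sum of values, then merge heads back to d_model
    out = (weights @ v).transpose(1, 2).reshape(b, s, d_model)
    return out @ w_o
```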

KV-Cache

Transformer models are huge, and different optimization techniques have appeared to speed up inference. The objective of our model at inference time is to produce the token that follows the last token in the input sequence (next-token prediction). What do we need to compute this? Just the final embedding of the last token, the last token representation of the Value Reweighted.

To obtain this embedding, last token representation of the Value Reweighted, we need to compute the dot product between the Attention Scores and the Value Vectors.

The Value Layer Weights are always the same, as we are at inference time, so the Value Vectors for the tokens before the last one do not change and we just need to compute the new last one. We can cache the Value Vectors from previous tokens and reuse them.

Regarding the Attention Scores, to obtain the last token representation of the Value Reweighted we just need the Attention Scores of the last token. To calculate this final attention row, we just need the Query Vector of the last token and its dot product with all the Key Vectors. But the Key Vectors from previous tokens are the same as when they were computed in the previous step, so we can cache the Key Vectors from previous tokens and reuse them.

And with this last token embedding and the cached Key and Value Vectors, we can iterate over the Transformer blocks and compute the final representation of the last token.

Multi-Head Attention mechanism with KV-Cache.

Why don’t we cache the Query Vectors? Because we don’t need all the Query Vectors, just the last one. To calculate the last token’s Value Reweighted we just need the Attention Scores of the last token, and for that we just need the last Query Vector.
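To illustrate the idea, below is a sketch of a single decoding step with a KV-Cache for one head (hypothetical function and variable names, batch dimension omitted for clarity): only the last token is projected, its new Key and Value Vectors are appended to the cache, and attention is computed between the last Query and all cached Keys.

```python
import torch
import torch.nn.functional as F

def decode_step_with_kv_cache(x_last, w_q, w_k, w_v, k_cache, v_cache):
    """One autoregressive step for a single head.

    x_last:  (1, d_model)    embedding of the last token only
    k_cache: (t, head_size)  cached Keys of the t previous tokens
    v_cache: (t, head_size)  cached Values of the t previous tokens
    """
    head_size = w_k.shape[1]

    # Project only the last token; previous Keys/Values are reused from the cache.
    q_last = x_last @ w_q                                   # (1, head_size)
    k_cache = torch.cat([k_cache, x_last @ w_k], dim=0)     # (t+1, head_size)
    v_cache = torch.cat([v_cache, x_last @ w_v], dim=0)     # (t+1, head_size)

    # Attention row of the last token against all cached Keys.
    scores = q_last @ k_cache.T / head_size ** 0.5          # (1, t+1)
    weights = F.softmax(scores, dim=-1)

    # Reweighted Values for the last token; return updated caches for the next step.
    out = weights @ v_cache                                 # (1, head_size)
    return out, k_cache, v_cache
```

In practice the caches carry batch and head dimensions as well, but the pattern is the same: each step appends one Key/Value pair and computes a single attention row.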

By caching the Key and Value Vectors, we can save computations and speed up the inference process. The cost of caching those vectors is $l \times b \times n \times h \times s \times 2 \times 2$, where the final terms $2 \times 2$ correspond to the number of caches per layer (i.e. one for Key and one for Value) and the number of bytes per floating point value (assuming each model parameter is a 16-bit floating point value), respectively. The rest of the terms are listed in the following table, and a small numeric sketch follows it:

KV-Cache hyperparameters to take into account.

| Hyperparameter | Description |
| -------------- | ----------- |
| $l$ | Number of Transformer layers in the model. |
| $b$ | Batch size. |
| $n$ | Number of attention heads. |
| $h$ | Head hidden size. |
| $s$ | Sequence length. |
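As a quick sanity check of the formula, here is a small sketch that evaluates it; the GPT-3-like values (96 layers, 96 heads, head size 128, sequence length 1024, batch size 1) are illustrative assumptions.

```python
def kv_cache_bytes(l, b, n, h, s, bytes_per_value=2):
    """KV-cache size: l * b * n * h * s * 2 (one K and one V cache) * bytes per value."""
    return l * b * n * h * s * 2 * bytes_per_value

# Illustrative GPT-3-like configuration (assumed values).
size = kv_cache_bytes(l=96, b=1, n=96, h=128, s=1024)
print(f"{size / 1024**3:.1f} GiB")  # ~4.5 GiB for this configuration
```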

Can we further reduce the cache needed? Yes, we can! By reducing the size of the cache we reduce the memory required and the data to load, and therefore speed up inference. That is, at the cost of some of our model’s performance.

Multi-Query Attention

If we look closer at the Multi-Head Attention mechanism, we can see that the Key and Value Vectors have a different set of weights for each head. How can we reduce the cache needed for the Key and Value Vectors? Simple: we can reduce the number of Key and Value heads. In particular, Multi-Query Attention proposes to use a single set of Key and Value Vectors, shared across all Query heads. This will obviously reduce the model’s capacity to capture different aspects of the relationships between tokens, since we have fewer heads to model those relationships, and therefore the model’s performance will be affected. But it will reduce the cache needed and the inference time.

Multi-Query Attention mechanism.
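A minimal sketch of the idea (hypothetical names, causal mask omitted for brevity): the Query projection still produces n_heads heads, while a single shared Key/Value head is broadcast against all of them, so only one Key/Value pair per token needs to be cached.

```python
import torch
import torch.nn.functional as F

def multi_query_attention(x, w_q, w_k, w_v, w_o, n_heads):
    """Multi-Query Attention sketch: many Query heads, one shared K/V head.

    x:   (batch, seq, d_model),     d_model = n_heads * head_size
    w_q: (d_model, d_model)         per-head Query projections
    w_k, w_v: (d_model, head_size)  single shared Key/Value projection
    """
    b, s, d_model = x.shape
    head_size = d_model // n_heads

    q = (x @ w_q).view(b, s, n_heads, head_size).transpose(1, 2)  # (b, n_heads, s, hs)
    k = (x @ w_k).unsqueeze(1)                                    # (b, 1, s, hs)
    v = (x @ w_v).unsqueeze(1)                                    # (b, 1, s, hs)

    # The single K/V head is broadcast against all n_heads Query heads.
    scores = q @ k.transpose(-2, -1) / head_size ** 0.5           # (b, n_heads, s, s)
    weights = F.softmax(scores, dim=-1)
    out = (weights @ v).transpose(1, 2).reshape(b, s, d_model)
    return out @ w_o
```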

The cost of caching the Key and Value Vectors is $l \times b \times h \times s \times 2 \times 2$. Note how the number of attention heads $n$ is not present in the formula.

Inefficient Parallelization

Many large Transformer models require more memory than fits on a single GPU, so they employ model parallelism to distribute the model across multiple GPUs. This is usually accomplished with Tensor Parallelism, assigning different attention heads to different GPUs.

For example, the 175B GPT-3 model has 96 attention heads, and if we have 8 GPUs, we can assign 12 heads to each GPU.

Each GPU performs the entire attention computation for its assigned heads, and then all GPUs exchange the results of their computations. With Multi-Query Attention, every attention head shares the same set of Key and Value Vectors, so each GPU computes and caches the same Key and Value Vectors. This is inefficient: we duplicate the same computation on every GPU while, at the same time, reducing the capacity of the model.

Grouped-Query Attention

Grouped-Query Attention tries to solve the inefficient parallelization of Multi-Query Attention. To do so, it proposes to increase the number of Key and Value heads to an arbitrary number, a hyperparameter $G$. The Query heads are grouped into $G$ groups, and each group has a different set of Key and Value Vectors.

Grouped-Query Attention mechanism.
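A sketch of the grouping (hypothetical names, causal mask omitted for brevity): there are n_groups Key/Value heads and n_heads Query heads, and each Key/Value head is shared by n_heads // n_groups Query heads.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(x, w_q, w_k, w_v, w_o, n_heads, n_groups):
    """Grouped-Query Attention sketch: n_heads Query heads, n_groups K/V heads.

    x:   (batch, seq, d_model),  d_model = n_heads * head_size
    w_q: (d_model, d_model)
    w_k, w_v: (d_model, n_groups * head_size)
    """
    b, s, d_model = x.shape
    head_size = d_model // n_heads

    q = (x @ w_q).view(b, s, n_heads, head_size).transpose(1, 2)   # (b, n_heads, s, hs)
    k = (x @ w_k).view(b, s, n_groups, head_size).transpose(1, 2)  # (b, n_groups, s, hs)
    v = (x @ w_v).view(b, s, n_groups, head_size).transpose(1, 2)

    # Each K/V head serves n_heads // n_groups Query heads.
    k = k.repeat_interleave(n_heads // n_groups, dim=1)            # (b, n_heads, s, hs)
    v = v.repeat_interleave(n_heads // n_groups, dim=1)

    scores = q @ k.transpose(-2, -1) / head_size ** 0.5            # (b, n_heads, s, s)
    weights = F.softmax(scores, dim=-1)
    out = (weights @ v).transpose(1, 2).reshape(b, s, d_model)
    return out @ w_o
```

Note that with n_groups = 1 this reduces to Multi-Query Attention, and with n_groups = n_heads it recovers Multi-Head Attention.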

When operating in a multi-GPU environment with tensor parallelism, we can get essentially the same speed-up as Multi-Query Attention by setting $G$ equal to the number of GPUs, but with the advantage of giving the model more capacity and better performance. The cost of caching the Key and Value Vectors is $l \times b \times G \times h \times s \times 2 \times 2$. Note how the number of attention heads $n$ is not present in the formula, having been replaced by the hyperparameter $G$.

Speed-Up References

To put the speed-ups of the different attention mechanisms in context, consider the following table:

Speed-ups of the different attention mechanisms.

| Model | Parameter Count | KV Cache Size (MHA) | KV Cache Size (MQA) | KV Cache Size (GQA) |
| ----- | --------------- | ------------------- | ------------------- | ------------------- |
| GPT-3 | 175B            | 4.5GB               | 48MB                | 384MB               |
