Flash Attention 2
The seminal July 2023 paper FlashAttention-2, a sequel to the original FlashAttention, describes a method that significantly boosts the efficiency of Transformer models.
Transformers are at the heart of generative AI models for text, code, audio, and video, so their scalability is paramount.
However, the attention mechanism, a critical component of Transformers, has been a bottleneck because its memory and runtime grow quadratically with sequence length.
The paper's main goal is therefore to address the quadratic runtime and memory requirements of the attention layer, which hinder the scalability of Transformers to longer sequences.
FlashAttention-2, developed by Tri Dao of the Department of Computer Science at Princeton University and Stanford University, tackled this issue head-on.
The Leap from FlashAttention to FlashAttention-2
FlashAttention made waves by reordering the computation of attention to leverage the GPU memory hierarchy more effectively, significantly speeding up the process and reducing memory usage from quadratic to linear in sequence length.
This yielded a 2-4 times speedup over optimised baselines and up to 10-20 times memory savings, with no approximation.
Despite its success, FlashAttention still fell short of reaching the efficiency of optimised General Matrix Multiply (GEMM) operations, achieving only 25-40% of the theoretical maximum FLOPs/s.
Some definitions
Sequence length refers to the number of tokens in an input sequence. For example, in natural language processing tasks, the sequence length is the number of words or subwords in a given text input.
Batch Dimension: Data is often processed in batches to improve computational efficiency and convergence properties. The batch dimension refers to the number of samples or sequences that are processed together in a single forward or backward pass of the model.
Instead of processing one sequence at a time, multiple sequences are stacked together to form a batch. This allows for parallel computation and better utilisation of GPUs.
GEMM: GEMM stands for General Matrix Multiplication, which is a fundamental operation in linear algebra and is used in attention mechanisms.
GEMM involves multiplying two matrices (often denoted as A and B) and adding the result to a third matrix (C).
The standard formula for GEMM is:
C ← αAB + βC
where α and β are scalar coefficients (usually set to 1 and 0, respectively, for standard matrix multiplication).
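As a concrete illustration, here is a minimal NumPy sketch of the GEMM update; the matrix sizes are arbitrary and chosen only for the example.

```python
import numpy as np

# Minimal GEMM illustration: C <- alpha * A @ B + beta * C
alpha, beta = 1.0, 0.0           # defaults for plain matrix multiplication
A = np.random.randn(4, 3)        # (m, k)
B = np.random.randn(3, 5)        # (k, n)
C = np.zeros((4, 5))             # (m, n)

C = alpha * (A @ B) + beta * C   # the general GEMM update
assert np.allclose(C, A @ B)     # with alpha=1, beta=0 this is a plain matmul
```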
GEMM and the Attention Mechanism
In the context of attention mechanisms, GEMM is used to compute the attention scores S = QK^T and the attention output O = PV, where:
Q is the query matrix
K is the key matrix
V is the value matrix
P is the attention probability matrix (obtained by applying softmax to the attention scores S)
The dimensions of these matrices are:
Q, K ∈ ℝ^(N×d), with dimensions N × d, where N is the sequence length and d is the head dimension
V ∈ ℝ^(N×d), with dimensions N × d
Output O ∈ ℝ^(N×d), with dimensions N × d
In this notation:
The symbol "∈" is the mathematical notation for "belongs to" or "is an element of." In this context, it is used to indicate that a matrix belongs to a specific set of dimensions.
ℝ represents the set of real numbers
N and d represent the dimensions of the matrices, with N being the number of rows and d or N being the number of columns, depending on the matrix
The shortfalls of the standard attention mechanism
The standard attention mechanism is a fundamental component of Transformer models.
However, as highlighted, it suffers from a significant limitation: its inability to scale efficiently as the sequence length (the number of tokens or words in the input) increases.
Attention Mechanism - how does it work?
Given an input sequence of length N and a head dimension d, the goal of the attention mechanism is to compute the attention output O, which is an N × d matrix.
The attention mechanism then operates as follows:
Compute the attention scores S:
S = QK^T
where Q and K are the query and key matrices, respectively, both of size N × d.
The attention scores matrix S is an N × N matrix that represents the similarity or relevance between each pair of tokens in the input sequence.
Apply the softmax function to the attention scores S row-wise:
P = softmax(S)
The softmax function normalises the attention scores, converting them into attention probabilities P, which is also an N × N matrix.
Each row of P sums to 1, representing the distribution of attention weights for a specific token.
Compute the attention output O:
O = PV
where V is the value matrix of size N × d.
The attention output O is an N × d matrix that represents the weighted sum of the value vectors, with the weights determined by the attention probabilities P.
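To make these three steps concrete, here is a minimal NumPy sketch of single-head standard attention. The usual 1/√d scaling of the scores is omitted to match the formulation above, and the sizes are arbitrary.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Standard single-head attention: O = softmax(Q K^T) V.
    Q, K, V have shape (N, d); the usual 1/sqrt(d) scaling is omitted
    to match the formulation above."""
    S = Q @ K.T                                   # attention scores, (N, N)
    S = S - S.max(axis=-1, keepdims=True)         # subtract row max for stability
    P = np.exp(S) / np.exp(S).sum(axis=-1, keepdims=True)  # row-wise softmax, (N, N)
    return P @ V                                  # attention output, (N, d)

N, d = 8, 4                                       # tiny sizes for illustration
Q, K, V = (np.random.randn(N, d) for _ in range(3))
O = standard_attention(Q, K, V)
assert O.shape == (N, d)
```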
Forward and Backward Pass
In the forward pass, the input sequences are processed in parallel across the batch dimension - multiple sequences are processed simultaneously - and the attention computation is performed independently for each head.
This parallelisation enables efficient utilisation of hardware resources and faster computation.
The backward pass, which is crucial for training the model, involves computing the gradients of the loss with respect to the inputs Q, K, and V, given the gradient of the loss with respect to the output.
Let dO ∈ ℝ^(N×d) be the gradient of the loss with respect to the attention output O.
By applying the chain rule (backpropagation), we can compute the gradients as follows:
Compute the gradient of V: dV = P^T dO
Compute the gradient of P: dP = dO V^T
Compute the gradient of S using the gradient of the softmax function: dS = dsoftmax(dP)
Compute the gradient of Q: dQ = dS K
Compute the gradient of K: dK = dS^T Q
The function dsoftmax denotes the gradient of the softmax function applied row-wise.
It can be shown that if p = softmax(s) for some vectors s and p, then with the output gradient dp, the input gradient ds is given by:
ds = (diag(p) - p p^T) dp
where diag(p) is a diagonal matrix with the elements of p on its diagonal.
This formula captures the relationship between the input and output gradients of the softmax function, considering the interdependence of the output probabilities.
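The chain-rule steps above can be written out directly. The sketch below is an illustrative NumPy implementation of the backward pass for a single head (not the FlashAttention-2 kernel), with a small finite-difference spot check of the softmax-gradient formula.

```python
import numpy as np

def attention_backward(Q, K, V, dO):
    """Gradients of O = softmax(Q K^T) V with respect to Q, K, V
    (single head, no 1/sqrt(d) scaling). Illustration only."""
    S = Q @ K.T
    P = np.exp(S - S.max(-1, keepdims=True))
    P /= P.sum(-1, keepdims=True)                  # row-wise softmax, (N, N)

    dV = P.T @ dO                                  # dV = P^T dO
    dP = dO @ V.T                                  # dP = dO V^T
    # Row-wise softmax gradient: ds = (diag(p) - p p^T) dp = p * (dp - <p, dp>)
    dS = P * (dP - (P * dP).sum(-1, keepdims=True))
    dQ = dS @ K                                    # dQ = dS K
    dK = dS.T @ Q                                  # dK = dS^T Q
    return dQ, dK, dV

# Finite-difference spot check on one entry of dQ.
N, d = 5, 3
Q, K, V, dO = (np.random.randn(N, d) for _ in range(4))
dQ, dK, dV = attention_backward(Q, K, V, dO)

def loss(Qp):
    S = Qp @ K.T
    P = np.exp(S - S.max(-1, keepdims=True))
    P /= P.sum(-1, keepdims=True)
    return ((P @ V) * dO).sum()      # <O, dO>, so its gradient w.r.t. Q is dQ

eps = 1e-5
Q_shift = Q.copy()
Q_shift[0, 0] += eps
assert abs((loss(Q_shift) - loss(Q)) / eps - dQ[0, 0]) < 1e-3
```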
Memory Limitations
In standard attention implementations, the attention scores matrix S and the attention probabilities matrix P are materialised and stored in the high-bandwidth memory (HBM) of the GPU.
The memory requirement for storing these matrices is proportional to N², where N is the sequence length.
This quadratic memory requirement becomes a significant limitation when dealing with long sequences.
In typical applications, the sequence length N can range from 1,000 to 8,000 tokens, while the head dimension d is usually much smaller, around 64-128.
As a result, the memory consumption grows rapidly with increasing sequence length, making it challenging to process long sequences efficiently.
To put this into perspective, let's consider an example where N = 5,000 and d = 64.
The attention scores matrix S and the attention probabilities matrix P would each require storing 5,000 × 5,000 = 25,000,000 elements.
Assuming each element is a 32-bit floating-point number (4 bytes), each matrix occupies about 100 MB, so the total memory requirement for both matrices would be approximately 200 MB.
This memory consumption can quickly become prohibitive, especially when dealing with larger sequence lengths or multiple attention heads.
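A quick back-of-the-envelope check of that estimate (per attention head and per sequence in the batch):

```python
N = 5_000                       # sequence length
bytes_per_element = 4           # 32-bit float
elements = N * N                # S and P are each N x N
total = 2 * elements * bytes_per_element      # S and P together
print(total / 1e6, "MB")        # -> 200.0 MB, per head and per sequence
```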
Conclusion
The standard attention mechanism, while powerful, suffers from a quadratic memory requirement that limits its scalability to long sequences.
This limitation arises from the need to materialise and store the attention scores and probabilities matrices, which grow quadratically with the sequence length.
To address this challenge, various alternative attention mechanisms have been proposed, such as sparse attention, linearised attention, and hierarchical attention.
These mechanisms have aimed to reduce the memory footprint and computational complexity of the attention mechanism while still capturing the essential dependencies in the input sequence.
To date, FlashAttention and FlashAttention-2 have proven to be among the most effective solutions to this quadratic problem: they compute exact attention while avoiding materialisation of the full N × N matrices.
Flash Attention 2
FlashAttention-2 is an improved version of the FlashAttention algorithm. It aims to further optimise the attention mechanism for longer sequences by reducing memory usage and improving computational efficiency.
FlashAttention-2 addresses the quadratic memory requirement issue of the standard attention mechanism by employing several techniques to reduce memory reads/writes and improve parallelism.
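In practice you rarely write these kernels yourself. As a hedged usage sketch, recent PyTorch versions expose fused attention kernels, including FlashAttention-based backends on supported GPUs, through torch.nn.functional.scaled_dot_product_attention; which backend is actually selected depends on your PyTorch version, GPU, dtype, and arguments.

```python
import torch
import torch.nn.functional as F

# Fused flash kernels need a GPU and half precision; on CPU this call
# falls back to the standard math implementation.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Shapes: (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 8, 2048, 64, device=device, dtype=dtype) for _ in range(3))

# Mathematically this is softmax(q k^T / sqrt(d)) v, but the fused backends
# never materialise the full attention matrix in HBM.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 2048, 64])
```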
Tiling and Block-wise Computation
FlashAttention-2 employs a technique called tiling and block-wise computation to reduce memory input/output (I/O) operations.
This is achieved by dividing the input matrices Q, K, and V and the output matrix O into smaller blocks.
The purpose of this division is to ensure that these smaller blocks can fit into the GPU's on-chip SRAM (Static Random Access Memory), which is much faster than the main high-bandwidth memory (HBM).
By operating on these smaller blocks, FlashAttention-2 minimises the need for frequent data transfers between the main memory and the GPU, thereby reducing the memory I/O bottleneck and improving overall performance.
Online Softmax Computation
FlashAttention-2 uses the online softmax technique to enable block-wise computation of attention.
Instead of computing the softmax function over the entire attention scores matrix S, FlashAttention-2 computes the softmax function for each block independently.
It maintains a running maximum and a running sum of exponentials for each block, which allows for the correct computation of attention probabilities without materialising the entire S and P matrices.
The online softmax technique ensures that the output of each block is correctly scaled, resulting in the same output as the standard attention mechanism, but with reduced memory requirements.
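The following NumPy sketch illustrates the idea for one block of query rows: K and V are streamed in column blocks, and a running row maximum and running sum of exponentials are carried along so that the final result matches standard attention without ever forming the full N × N matrices. It illustrates the algorithm only, not the actual CUDA kernel, and the block sizes are arbitrary.

```python
import numpy as np

def online_softmax_attention_block(Q_blk, K, V, block_size=64):
    """Online-softmax attention for one block of query rows
    (illustrative sketch, no 1/sqrt(d) scaling)."""
    Br, d = Q_blk.shape
    O = np.zeros((Br, d))            # running (unnormalised) output
    l = np.zeros(Br)                 # running sum of exponentials per row
    m = np.full(Br, -np.inf)         # running row maximum

    for j in range(0, K.shape[0], block_size):        # stream K, V in column blocks
        K_blk, V_blk = K[j:j + block_size], V[j:j + block_size]
        S = Q_blk @ K_blk.T                           # scores for this block only
        m_new = np.maximum(m, S.max(axis=1))
        scale = np.exp(m - m_new)                     # rescale previous accumulators
        P = np.exp(S - m_new[:, None])
        l = scale * l + P.sum(axis=1)
        O = scale[:, None] * O + P @ V_blk
        m = m_new
    return O / l[:, None]                             # final normalisation

# Matches standard attention on a small example.
N, d = 256, 32
Q, K, V = (np.random.randn(N, d) for _ in range(3))
ref_S = Q @ K.T
ref_P = np.exp(ref_S - ref_S.max(1, keepdims=True))
ref_O = (ref_P / ref_P.sum(1, keepdims=True)) @ V
out = np.vstack([online_softmax_attention_block(Q[i:i + 64], K, V)
                 for i in range(0, N, 64)])
assert np.allclose(out, ref_O, atol=1e-6)
```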
Recomputation and Gradient Checkpointing
In the backward pass, FlashAttention-2 avoids storing the large intermediate matrices S and P by recomputing their values once the blocks of inputs are loaded into on-chip SRAM.
By not having to save the S and P matrices, which each have size N × N (where N is the sequence length), FlashAttention-2 achieves a significant memory saving, reducing the memory requirement from quadratic to linear in sequence length.
The recomputation strategy, also known as gradient checkpointing, trades off some additional computation for reduced memory usage, enabling the processing of longer sequences with limited memory resources.
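Gradient checkpointing as a general technique is available directly in PyTorch. The sketch below shows the generic trade-off (recompute activations in the backward pass rather than storing them) applied to an ordinary feed-forward block; FlashAttention-2 applies the same principle inside its fused attention kernel rather than through torch.utils.checkpoint.

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim), torch.nn.GELU(), torch.nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        # checkpoint() discards the intermediate activations of self.ff during
        # the forward pass and recomputes them during backward: extra compute,
        # much lower activation memory -- the same trade-off FlashAttention-2
        # makes by recomputing S and P block by block.
        return x + checkpoint(self.ff, x, use_reentrant=False)

x = torch.randn(8, 128, 512, requires_grad=True)
y = Block()(x).sum()
y.backward()
```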
Improved Parallelism
FlashAttention-2 enhances parallelism by processing different blocks of the attention computation on different thread blocks of the GPU.
In the forward pass, FlashAttention-2 parallelises the computation across the sequence length dimension, in addition to the batch and head dimensions. Each thread block handles a block of rows of the attention matrix, allowing for efficient utilisation of GPU resources, especially when dealing with long sequences and small batch sizes.
In the backward pass, FlashAttention-2 parallelises the computation across the sequence length dimension as well, with each thread block handling a block of columns of the attention matrix. This parallelisation scheme improves the utilisation of GPU resources and reduces the overall runtime.
Optimised Work Partitioning
FlashAttention-2 optimises the work partitioning between different warps within each thread block to reduce communication and shared memory reads/writes.
In the forward pass, FlashAttention-2 splits the query matrix across different warps, while keeping the key and value matrices accessible by all warps. This approach eliminates the need for communication between warps and reduces shared memory reads/writes compared to the "split-K" scheme used in the original FlashAttention.
In the backward pass, FlashAttention-2 also avoids the "split-K" scheme and partitions the work between warps to minimise shared memory reads/writes, leading to improved performance.
What is the "split-K" scheme?
In the context of attention computation, "split-K" refers to a scheme where the key matrix K is split across different warps (groups of threads) within a thread block.
In the original FlashAttention implementation, the key matrix K and value matrix V are divided and assigned to different warps, while the query matrix Q is accessible by all warps.
After each warp performs a matrix multiplication between the shared query matrix Q and its assigned portion of K^T to compute a slice of the attention scores S, the warps need to communicate and exchange intermediate results through shared memory.
This communication and shared memory reads/writes introduce overhead and can slow down the computation.
FlashAttention-2 addresses this issue by using a different work partitioning scheme in the forward pass, where the query matrix Q is split across warps instead of the key matrix K, as the sketch below illustrates.
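The difference can be illustrated at a high level in NumPy, treating each "warp" as a slice of the work. With split-K, every warp holds only partial, unnormalised results for a row, so a combination step across warps is needed before the row-wise softmax is complete; with split-Q, each warp owns whole rows and finishes its output independently. This is a conceptual sketch only; real warps are hardware thread groups, not Python loop iterations.

```python
import numpy as np

N, d, n_warps = 8, 4, 2
Q, K, V = (np.random.randn(N, d) for _ in range(3))

# Split-K (original FlashAttention): each "warp" sees a slice of K and V and
# produces partial, unnormalised results that must be combined across warps.
partials = []
for Kw, Vw in zip(np.array_split(K, n_warps), np.array_split(V, n_warps)):
    S_w = Q @ Kw.T                                # partial scores: some columns of S
    partials.append((np.exp(S_w).sum(1), np.exp(S_w) @ Vw))
# Combining the partial sums/outputs is the cross-warp communication step.
denom = sum(p[0] for p in partials)
O_split_k = sum(p[1] for p in partials) / denom[:, None]

# Split-Q (FlashAttention-2): each "warp" owns whole rows of Q, sees all of K
# and V, and computes its rows of the output with no cross-warp combination.
O_split_q = np.vstack([
    (np.exp(Qw @ K.T) / np.exp(Qw @ K.T).sum(1, keepdims=True)) @ V
    for Qw in np.array_split(Q, n_warps)
])
assert np.allclose(O_split_k, O_split_q)
```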
Empirical Validation and Performance Gains
The FlashAttention-2 paper provides detailed benchmarks and performance comparisons that demonstrate the significant improvements achieved by the optimised attention mechanism.
Attention Computation Speed
FlashAttention-2 showcases remarkable speed improvements in attention computation compared to previous implementations.
The paper presents comprehensive benchmarks on NVIDIA A100 GPUs, comparing FlashAttention-2 with PyTorch's standard attention implementation, the original FlashAttention, and FlashAttention implemented in the Triton language.
FlashAttention-2 achieves up to 3 times faster attention computation compared to the original FlashAttention and up to 10 times faster than PyTorch's standard attention implementation.
For example, on sequences of length 2048 with a head dimension of 128, FlashAttention-2 attains an impressive 187 TFLOPs/s in the forward pass, surpassing FlashAttention's 120 TFLOPs/s and PyTorch's 60 TFLOPs/s.
These speed improvements are consistent across various sequence lengths and head dimensions, showcasing FlashAttention-2's ability to efficiently utilise GPU resources and minimise memory bottlenecks.
Backward Pass Efficiency
The backward pass in attention mechanisms is known to be computationally challenging due to the dependencies between gradients and the need to store intermediate values.
However, FlashAttention-2 tackles this challenge head-on and achieves significant speedups in the backward pass as well. The paper presents detailed benchmarks comparing the backward pass performance of FlashAttention-2 with other implementations.
On A100 GPUs, FlashAttention-2 demonstrates up to 2 times faster backward passes compared to the original FlashAttention.
It also outperforms PyTorch's standard attention implementation by a significant margin.
For instance, on sequences of length 4096 with a head dimension of 128, FlashAttention-2 achieves 165 TFLOPs/s in the backward pass, while FlashAttention reaches 80 TFLOPs/s and PyTorch's implementation obtains only 45 TFLOPs/s.
This showcases FlashAttention-2's ability to efficiently compute gradients and optimise memory usage in the backward pass.
The implications of FlashAttention-2's performance improvements are far-reaching. It enables faster training and inference of Transformer models, opening up possibilities for handling longer sequences and more complex tasks.
The Impact of FlashAttention-2
FlashAttention-2 represents a significant advancement in the field of efficient Transformer models, particularly in addressing the limitations of the standard attention mechanism.
FlashAttention-2 achieves remarkable improvements in both speed and memory efficiency.
The empirical results demonstrate that FlashAttention-2 can significantly outperform previous implementations, including the original FlashAttention and standard attention implementations in PyTorch.
The implications of FlashAttention-2 extend beyond just raw performance gains. It opens up new possibilities for training and deploying large Transformer models capable of handling longer sequences and more complex tasks.
As the demand for AI models that can understand and generate longer contexts continues to grow, FlashAttention-2 provides a promising solution to address the challenges associated with scaling attention mechanisms.
Paper References
Ainslie et al. (2023) and Shazeer (2019) discuss techniques for training and optimising Transformer models, such as multi-query attention and fast transformer decoding.
Several papers, including Beltagy et al. (2020), Katharopoulos et al. (2020), Kitaev et al. (2020), Roy et al. (2021), Wang et al. (2020), and Zaheer et al. (2020), propose alternative attention mechanisms to address the quadratic complexity of standard attention. These techniques include sparse attention, linearised attention, and hierarchical attention.
Chen et al. (2021) and Choromanski et al. (2020) introduce novel attention mechanisms, such as Scatterbrain and Performers, to improve the efficiency of Transformers.
Dao et al. (2022) propose the original FlashAttention algorithm, which serves as the foundation for FlashAttention-2.
Jia and Van Sandt (2021) and Jia et al. (2018) provide insights into the architecture and performance characteristics of NVIDIA GPUs, which are relevant for optimising attention computations.
Lefaudeux et al. (2022) introduce the xformers library, which includes implementations of various efficient attention mechanisms.
Milakov and Gimelshein (2018) and Rabe and Staats (2021) discuss techniques for optimising the softmax computation in attention mechanisms.
OpenAI (2023) presents the technical details of the GPT-4 model, which demonstrates the importance of scaling Transformers to longer sequences.
Shoeybi et al. (2019) introduce Megatron-LM, a large-scale language model trained using model parallelism, highlighting the need for efficient attention mechanisms in training billion-parameter models.
Tillet et al. (2019) propose Triton, an intermediate language and compiler for tiled neural network computations, which is used in some implementations of FlashAttention.
Vaswani et al. (2017) introduce the original Transformer architecture and the standard attention mechanism, which serves as the basis for subsequent research on efficient attention.