Flash Attention 2

The seminal July 2023 paper

FlashAttention-2, a sequel to the original FlashAttention, demonstrates an approach that can significantly boost the efficiency of Transformer models.

Because Transformers are at the heart of generative AI models for text, code, audio, and video, their scalability is paramount.

However, the attention mechanism, a critical component of Transformers, has been a bottleneck due to its quadratic memory and runtime increase with sequence length.

The main goal, therefore, is to address the quadratic runtime and memory requirements of the attention layer, which hinder the scalability of Transformers to longer sequences.

FlashAttention-2, developed by Tri Dao from the departments of Computer Science at Princeton University and Stanford University, tackled this issue head-on.

Interview with the inventor

A summary of the video transcript with Tri Dao

Tri Dao, the inventor of Flash Attention and Flash Attention 2, shares his insights on the development of these techniques, the current state of AI research, and his perspectives on the future of the field. Let's break down the key points from the discussion.

Insights and Ideas behind Flash Attention

  • Dao and his team focused on making attention more memory-efficient without approximations, as most practitioners preferred exact methods over approximate ones.

  • They realised that the majority of time in attention computation was spent on memory reading and writing, not floating-point operations (FLOPs).

  • By combining ideas from the systems side (kernel fusion, tiling) and the machine learning side (the online softmax trick), they developed Flash Attention, which reduces memory usage from quadratic to linear in sequence length.

  • Flash Attention provides a 2-4x speedup compared to the standard attention implementation, allowing for training with longer sequences at the same cost.

Development of Flash Attention 2

  • Dao started working on Flash Attention 2 as an excuse to play with NVIDIA's CUTLASS library, which provides primitives for efficient GPU operations.

  • He realised there were still inefficiencies in Flash Attention and made tweaks to reduce non-matmul FLOPs, as modern GPUs have specialised units (e.g., Tensor Cores) that make matmul much faster than non-matmul operations.

  • Flash Attention 2 also parallelises the computation across different thread blocks and partitions the work between warps to reduce shared memory access.

  • These optimisations result in a 2x speedup compared to Flash Attention, approaching the efficiency of highly-optimized matrix multiplication routines.

Explanation of GPU Performance Characteristics

  • GPUs have a memory hierarchy consisting of high-bandwidth memory (HBM) and on-chip SRAM (shared memory).

  • HBM is larger but slower, while SRAM is much smaller but faster.

  • The goal is to design algorithms that take advantage of this asymmetric memory hierarchy, loading data from HBM to SRAM, performing computations, and then writing results back to HBM.

  • As sequence lengths grow, the memory required for attention matrices increases quadratically, while the head dimension remains relatively small, leading to a memory inefficiency that Flash Attention aims to address.

Role of Academia and Industry in AI Research

  • Academia and industry play complementary roles in AI research.

  • Industry has an advantage in scaling models and providing infrastructure, while academia focuses on developing new ideas, understanding and explaining models, and taking riskier bets on novel architectures.

  • Dao believes it's important for researchers to maintain a balance between exploring fundamentals and keeping up with the latest developments in the field.

Transformer Alternatives and Future Directions

  • Dao is interested in exploring transformer alternatives, such as recurrent neural networks (RNNs) and state space models, which could be more efficient for long sequences and high-throughput generation.

  • He believes these alternatives could have advantages in specific use cases, such as processing high-resolution images, audio, and video.

  • However, he also acknowledges the "hardware lottery" and the challenges of developing new architectures that can compete with the heavily-optimised transformer on current hardware and software frameworks.

The Leap from FlashAttention to FlashAttention-2

FlashAttention made waves by reordering the computation of attention to leverage the GPU memory hierarchy more effectively, significantly speeding up the process and reducing memory usage from quadratic to linear in sequence length.

This yielded a 2-4 times speedup over optimised baselines and up to 10-20 times memory savings, with no approximation.

Despite its success, FlashAttention still fell short of reaching the efficiency of optimised General Matrix Multiply (GEMM) operations, achieving only 25-40% of the theoretical maximum FLOPs/s.

Refresher: GPU Performance Characteristics

Compute elements

GPUs consist of compute elements, such as floating-point arithmetic units, which perform the actual computations. These units are designed to handle parallel processing tasks efficiently.

Specialised matrix multiply units

Modern GPUs have specialised units, such as Nvidia's Tensor Cores, that accelerate matrix multiplication in low-precision formats such as FP16 (half-precision floating point) or BF16 (bfloat16).

These units are optimised for fast and efficient matrix multiplication, which is a common operation in deep learning workloads.

Memory hierarchy

GPUs have a memory hierarchy that consists of two main components:

High Bandwidth Memory (HBM): HBM is a type of high-performance memory that offers high bandwidth and capacity. For example, the A100 GPU has 40-80GB of HBM with a bandwidth of 1.5-2.0TB/s. HBM is used to store large datasets and intermediate results during computation.

On-chip SRAM (Shared Memory): GPUs also have on-chip SRAM, known as shared memory, which is a fast but limited-capacity memory. In the A100 GPU, each of the 108 streaming multiprocessors (SMs) has 192KB of shared memory with an estimated bandwidth of around 19TB/s. Shared memory is used for fast communication and data sharing between threads within a thread block.

L2 cache: GPUs also have an L2 cache, which is larger but slower than shared memory. However, the L2 cache is not directly controllable by the programmer; it is managed by the hardware.

Execution Model

Threads: GPUs employ a massive number of threads to execute operations in parallel. These threads are lightweight and can be quickly switched to hide memory latency and maximise utilisation of the GPU's compute resources.

Thread blocks: Threads are organized into thread blocks, which are scheduled to run on the GPU's streaming multiprocessors (SMs). Each thread block contains a set of threads that can cooperate and share resources, such as shared memory.

Warps: Within each thread block, threads are further grouped into warps, typically consisting of 32 threads. Threads within a warp generally execute in lockstep, meaning they perform the same instruction on different data elements simultaneously (NVIDIA calls this SIMT, Single Instruction, Multiple Threads, a close relative of SIMD). Warps are the basic unit of execution on the GPU.

Communication and synchronization

Shuffle instructions: Threads within a warp can communicate using fast shuffle instructions, which allow them to exchange data directly without going through shared memory. This is useful for efficient data sharing and reduction operations within a warp.
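To make the shuffle idea concrete, here is a minimal Python simulation (not GPU code) of a warp-wide sum using XOR shuffles, the butterfly pattern usually written with CUDA's `__shfl_xor_sync`. The helper names `shfl_xor` and `warp_reduce_sum` are hypothetical, and the model covers only the data exchange between lanes, not timing or divergence.

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def shfl_xor(values, lane_mask):
    """Hypothetical stand-in for an XOR shuffle: lane i reads the value held by lane i ^ lane_mask."""
    return [values[lane ^ lane_mask] for lane in range(WARP_SIZE)]


def warp_reduce_sum(values):
    """Butterfly reduction: after log2(32) = 5 exchanges every lane holds the full sum."""
    if len(values) != WARP_SIZE:
        raise ValueError("expected one value per lane")
    vals = list(values)
    offset = WARP_SIZE // 2
    while offset > 0:
        partner = shfl_xor(vals, offset)  # register-to-register exchange, no shared memory
        vals = [a + b for a, b in zip(vals, partner)]
        offset //= 2
    return vals


result = warp_reduce_sum(list(range(WARP_SIZE)))
print(result[0], result[31])  # 496 496
```

Because every lane ends up holding the total, no extra broadcast step is needed, which makes this pattern convenient for row-wise reductions such as the max and sum required by softmax.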

Shared memory: Warps within a thread block can communicate by reading from and writing to shared memory. Shared memory acts as a fast, software-managed cache that enables efficient data sharing and coordination between warps.

Kernel execution: A GPU kernel is a function that is executed in parallel by multiple threads on the GPU. Each kernel loads inputs from HBM into registers and shared memory, performs computations, and then writes the outputs back to HBM. The kernel is launched with a specified grid and block size, which determines the number of thread blocks and threads per block.

Understanding these GPU performance characteristics and the execution model is essential for writing optimised GPU code that achieves high performance.

By leveraging the specialised matrix multiply units, using the memory hierarchy effectively, and organising threads and thread blocks appropriately, developers can maximise the utilisation of GPU resources and minimise memory latency, ultimately leading to faster execution of parallel workloads.

Some definitions

Sequence length refers to the number of tokens in an input sequence. For example, in natural language processing tasks, the sequence length is the number of words or subwords in a given text input.

Batch Dimension: Data is often processed in batches to improve computational efficiency and convergence properties. The batch dimension refers to the number of samples or sequences that are processed together in a single forward or backward pass of the model.

Instead of processing one sequence at a time, multiple sequences are stacked together to form a batch. This allows for parallel computation and better utilisation of GPUs.

GEMM: GEMM stands for General Matrix Multiplication, which is a fundamental operation in linear algebra and is used in attention mechanisms.

GEMM involves multiplying two matrices (often denoted $A$ and $B$) and adding the result to a third matrix ($C$).

The standard formula for GEMM is:

$$C = \alpha AB + \beta C$$

where $\alpha$ and $\beta$ are scalar coefficients (usually set to 1 and 0, respectively, for standard matrix multiplication).
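A plain-Python sketch of the GEMM formula, using nested lists instead of a BLAS library so that each term of $C = \alpha AB + \beta C$ is visible. The function name `gemm` and the example matrices are illustrative; production GEMM routines (for example cuBLAS or CUTLASS kernels) tile this loop nest to fit the memory hierarchy.

```python
def gemm(alpha, A, B, beta, C):
    """Return alpha * (A @ B) + beta * C for nested lists (A: m x k, B: k x n, C: m x n)."""
    m, k, n = len(A), len(B), len(B[0])
    if any(len(row) != k for row in A):
        raise ValueError("inner dimensions of A and B must match")
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            acc = sum(A[i][p] * B[p][j] for p in range(k))  # dot product of row i of A and column j of B
            row.append(alpha * acc + beta * C[i][j])
        out.append(row)
    return out


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = [[1, 1], [1, 1]]
print(gemm(1, A, B, 0, C))  # [[19, 22], [43, 50]]
print(gemm(2, A, B, 1, C))  # [[39, 45], [87, 101]]
```

With $\alpha = 1$ and $\beta = 0$ this reduces to a plain matrix product, which is the case used for $QK^\top$ and $PV$ in attention.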

GEMM and the Attention Mechanism

In the context of attention mechanisms, GEMM is used to compute the attention scores ($QK^\top$) and the attention output ($PV$), where:

  • $Q$ is the query matrix

  • $K$ is the key matrix

  • $V$ is the value matrix

  • $P$ is the attention probability matrix (obtained by applying softmax to the attention scores)

The dimensions of these matrices are:

  • $Q, K, V \in \mathbb{R}^{N \times d}$, where $N$ is the sequence length and $d$ is the head dimension

  • $P \in \mathbb{R}^{N \times N}$

  • Output $O \in \mathbb{R}^{N \times d}$

In this notation:

  1. The symbol "∈" is the mathematical notation for "belongs to" or "is an element of." Here it indicates that each matrix is an element of a particular set of real-valued matrices.

  2. $\mathbb{R}$ represents the set of real numbers, so $\mathbb{R}^{N \times d}$ is the set of real matrices with $N$ rows and $d$ columns

  3. $N \times d$ and $N \times N$ give the dimensions of the matrices, with $N$ being the number of rows and $d$ or $N$ being the number of columns, depending on the matrix

The shortfalls of the standard attention mechanism

The standard attention mechanism is a fundamental component of Transformer models.

However, as highlighted, it suffers from a significant limitation: its inability to scale efficiently as the sequence length (the number of tokens or words in the input) increases.

Attention Mechanism - how does it work?

Given an input sequence of length $N$ and a head dimension $d$, the goal of the attention mechanism is to compute the attention output $O$, which is an $N \times d$ matrix.

The attention mechanism then operates as follows:

Compute the attention scores $S$

$S = QK^\top$, where $Q$ and $K$ are the query and key matrices, respectively, both of size $N \times d$.

The attention score matrix $S$ is an $N \times N$ matrix that represents the similarity or relevance between each pair of tokens in the input sequence.

Apply the softmax function to the attention scores $S$ row-wise

$$P = \mathrm{softmax}(S)$$

The softmax function normalises the attention scores, converting them into attention probabilities $P$, which is also an $N \times N$ matrix.

Each row of $P$ sums to 1, representing the distribution of attention weights for a specific token.

Compute the attention output $O$

$O = PV$, where $V$ is the value matrix of size $N \times d$.

The attention output $O$ is an $N \times d$ matrix that represents the weighted sum of the value vectors, with the weights determined by the attention probabilities $P$.
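The three steps above can be sketched in plain Python as follows. This is a minimal illustration under two assumptions: the matrices are small nested lists, and, as in the formulas above, the usual $1/\sqrt{d}$ scaling of the scores is omitted. Note that it materialises the full $N \times N$ matrices $S$ and $P$, which is exactly the cost discussed in the following sections.

```python
import math


def matmul(A, B):
    """Naive matrix product of nested lists (A: m x k, B: k x n)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(M):
    return [list(col) for col in zip(*M)]


def softmax_row(s):
    m = max(s)  # subtracting the row max avoids overflow and leaves the result unchanged
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [x / total for x in e]


def standard_attention(Q, K, V):
    S = matmul(Q, transpose(K))           # step 1: N x N scores, fully materialised
    P = [softmax_row(row) for row in S]   # step 2: row-wise softmax, still N x N
    O = matmul(P, V)                      # step 3: N x d output
    return O, P


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
O, P = standard_attention(Q, K, V)
print([round(sum(row), 6) for row in P])  # [1.0, 1.0, 1.0]
```

As a sanity check, an all-zero query gives equal scores for every key, so the output is simply the mean of the value vectors.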

Forward and Backward Pass

In the forward pass, the input sequences are processed in parallel across the batch dimension - multiple sequences are processed simultaneously - and the attention computation is performed independently for each head.

This parallelisation enables efficient utilisation of hardware resources and faster computation.

The backward pass, which is crucial for training the model, computes the gradients of the loss with respect to the inputs $Q$, $K$ and $V$, given the gradient of the loss with respect to the output.

Let $dO \in \mathbb{R}^{N \times d}$ be the gradient of some loss function with respect to the attention output $O$.

By applying the chain rule (backpropagation), we can compute the gradients as follows:

  1. Compute the gradient of $V$: $dV = P^\top dO$

  2. Compute the gradient of $P$: $dP = dO\, V^\top$

  3. Compute the gradient of $S$ using the gradient of the softmax function: $dS = \mathrm{dsoftmax}(dP)$

  4. Compute the gradient of $Q$: $dQ = dS\, K$

  5. Compute the gradient of $K$: $dK = dS^\top Q$
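The five steps can be checked numerically. The sketch below is a plain-Python illustration, not how FlashAttention computes them: it implements steps 1 to 5 for the unscaled attention defined above, applies the row-wise softmax gradient in the algebraically equivalent elementwise form $p \odot (dp - p^\top dp)$, and compares the result with central finite differences of the scalar loss $\sum O \odot dO$, whose gradient with respect to $O$ is exactly $dO$. The helper `numeric_grad` and the example values are hypothetical.

```python
import math


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(M):
    return [list(col) for col in zip(*M)]


def softmax_row(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    t = sum(e)
    return [x / t for x in e]


def forward(Q, K, V):
    P = [softmax_row(r) for r in matmul(Q, transpose(K))]
    return matmul(P, V), P


def backward(Q, K, V, dO):
    _, P = forward(Q, K, V)                # P is recomputed here only to keep the sketch short
    dV = matmul(transpose(P), dO)          # 1. dV = P^T dO
    dP = matmul(dO, transpose(V))          # 2. dP = dO V^T
    dS = []                                # 3. dS = dsoftmax(dP), row by row
    for p, dp in zip(P, dP):
        dot = sum(pi * dpi for pi, dpi in zip(p, dp))
        dS.append([pi * (dpi - dot) for pi, dpi in zip(p, dp)])  # equals (diag(p) - p p^T) dp
    dQ = matmul(dS, K)                     # 4. dQ = dS K
    dK = matmul(transpose(dS), Q)          # 5. dK = dS^T Q
    return dQ, dK, dV


def numeric_grad(Q, K, V, dO, name, i, j, eps=1e-6):
    """Central difference of loss = sum(O * dO) with respect to entry (i, j) of Q, K or V."""
    def loss(mats):
        O, _ = forward(mats["Q"], mats["K"], mats["V"])
        return sum(o * g for ro, rg in zip(O, dO) for o, g in zip(ro, rg))

    base = {"Q": Q, "K": K, "V": V}
    plus = {k: [row[:] for row in mat] for k, mat in base.items()}
    minus = {k: [row[:] for row in mat] for k, mat in base.items()}
    plus[name][i][j] += eps
    minus[name][i][j] -= eps
    return (loss(plus) - loss(minus)) / (2 * eps)


Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
K = [[0.2, -0.3], [0.4, 0.1], [-0.5, 0.2]]
V = [[1.0, 0.5], [-0.2, 0.3], [0.7, -1.0]]
dO = [[0.3, -0.1], [0.2, 0.4], [-0.5, 0.1]]
dQ, dK, dV = backward(Q, K, V, dO)
print(dK[1][0], numeric_grad(Q, K, V, dO, "K", 1, 0))  # the two values should agree closely
```

The shape check is a quick way to catch transposition mistakes: $dS^\top Q$ is $N \times d$, matching $K$.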

The $\mathrm{dsoftmax}$ function denotes the gradient of the softmax function applied row-wise.

It can be shown that if $p = \mathrm{softmax}(s)$ for some vectors $s$ and $p$, then with the output gradient $dp$, the input gradient $ds$ is given by:

$$ds = (\mathrm{diag}(p) - pp^\top)\, dp$$

where $\mathrm{diag}(p)$ is a diagonal matrix with the elements of $p$ on its diagonal.

This formula captures the relationship between the input and output gradients of the softmax function, considering the interdependence of the output probabilities.
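A quick numerical check of this formula: the sketch below builds the full Jacobian $\mathrm{diag}(p) - pp^\top$ explicitly and compares it with the algebraically equivalent elementwise form $ds_i = p_i\,(dp_i - p^\top dp)$, which needs only $O(n)$ work and no $n \times n$ matrix. The function names and input values are illustrative.

```python
import math


def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    t = sum(e)
    return [x / t for x in e]


def dsoftmax_jacobian(p, dp):
    """ds = (diag(p) - p p^T) dp, building the full n x n Jacobian explicitly."""
    n = len(p)
    J = [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(n)] for i in range(n)]
    return [sum(J[i][j] * dp[j] for j in range(n)) for i in range(n)]


def dsoftmax_fast(p, dp):
    """Same result in O(n): ds_i = p_i * (dp_i - sum_j p_j * dp_j)."""
    dot = sum(pi * dpi for pi, dpi in zip(p, dp))
    return [pi * (dpi - dot) for pi, dpi in zip(p, dp)]


p = softmax([1.0, 2.0, 0.5])
dp = [0.3, -0.2, 0.1]
print(dsoftmax_jacobian(p, dp))
print(dsoftmax_fast(p, dp))
```

The components of $ds$ always sum to zero, reflecting the fact that the probabilities themselves always sum to 1.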

Memory Limitations

In standard attention implementations, the attention scores matrix SS and the attention probabilities matrix PP are materialised and stored in the high-bandwidth memory (HBM) of the GPU.

The memory requirement for storing these matrices is proportional to $N^2$, where $N$ is the sequence length.

This quadratic memory requirement becomes a significant limitation when dealing with long sequences.

In typical applications, the sequence length NN can range from 1,000 to 8,000 tokens, while the head dimension dd is usually much smaller, around 64-128.

As a result, the memory consumption grows rapidly with increasing sequence length, making it challenging to process long sequences efficiently.

To put this into perspective, let's consider an example where $N$ = 5,000 and $d$ = 64.

The attention scores matrix $S$ and the attention probabilities matrix $P$ would each require storing 5,000 × 5,000 = 25,000,000 elements.

Assuming each element is a 32-bit floating-point number, the total memory requirement for both matrices would be approximately 200 MB.
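The same arithmetic as a small helper. This is a back-of-the-envelope sketch under stated assumptions: only $S$ and $P$ are counted for the quadratic term, only $Q$, $K$, $V$ and $O$ for the linear term, and the batch of 8 sequences with 16 heads in the last line is a hypothetical configuration chosen for illustration.

```python
def quadratic_bytes(seq_len, bytes_per_element=4, batch=1, heads=1):
    """Memory for S and P: two seq_len x seq_len matrices per sequence per head."""
    return 2 * seq_len * seq_len * bytes_per_element * batch * heads


def linear_bytes(seq_len, head_dim, bytes_per_element=4, batch=1, heads=1):
    """Memory for Q, K, V and O: four seq_len x head_dim matrices per sequence per head."""
    return 4 * seq_len * head_dim * bytes_per_element * batch * heads


print(quadratic_bytes(5_000) / 1e6)                     # 200.0 (MB, FP32)
print(linear_bytes(5_000, 64) / 1e6)                    # 5.12 (MB, FP32)
print(quadratic_bytes(5_000, batch=8, heads=16) / 1e9)  # 25.6 (GB, FP32)
```

The comparison makes the gap explicit: for one head, the $N \times N$ intermediates need about 40 times more memory than the inputs and output combined, and the ratio grows linearly with $N$.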

This memory consumption can quickly become prohibitive, especially when dealing with larger sequence lengths or multiple attention heads.

As the figure above shows, SRAM is much faster than HBM but has far less capacity, because it is small on-chip memory. The figure also compares the bandwidth of HBM and SRAM. HBM is used to store the tensors, while compute operations work on pieces of those tensors loaded into SRAM.

Conclusion

The standard attention mechanism, while powerful, suffers from a quadratic memory requirement that limits its scalability to long sequences.

This limitation arises from the need to materialise and store the attention scores and probabilities matrices, which grow quadratically with the sequence length.

To address this challenge, various alternative attention mechanisms have been proposed, such as sparse attention, linearised attention, and hierarchical attention.

These mechanisms have aimed to reduce the memory footprint and computational complexity of the attention mechanism while still capturing the essential dependencies in the input sequence.

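Unlike these approaches, which typically approximate attention, Flash Attention and Flash Attention 2 compute exact attention while reducing memory use from quadratic to linear in sequence length (the number of FLOPs remains quadratic).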

The problem with the sequence length being larger than the head dimension

The significance of N being much larger than d lies in the memory complexity of the standard attention implementation.

Since the attention scores matrix S and the attention probabilities matrix P have dimensions N×N, the memory required to store these matrices grows quadratically with the sequence length N.

For example, if N is 1,000 (1k), the matrices would have 1,000,000 (1 million) elements, and if N is 8,000 (8k), the matrices would have 64,000,000 (64 million) elements!

In contrast, the head dimension d is typically much smaller than N.

Even with a head dimension of 128, which is on the higher end of the typical range, it is still significantly smaller than the sequence lengths encountered in practice.

This difference in scale between N and d highlights the memory inefficiency of the standard attention implementation.

As the sequence length N grows, the memory required to store the attention matrices increases quadratically, while the head dimension d remains relatively small.

This leads to a bottleneck in terms of memory consumption, especially when dealing with long sequences, making it challenging to process such sequences efficiently using the standard attention implementation.

Flash Attention 2

FlashAttention-2 is an improved version of the FlashAttention algorithm. It aims to further optimise the attention mechanism for longer sequences by reducing memory usage and improving computational efficiency.

FlashAttention-2 addresses the quadratic memory requirement issue of the standard attention mechanism by employing several techniques to reduce memory reads/writes and improve parallelism.

Tiling and Block-wise Computation

FlashAttention-2 employs a technique called tiling and block-wise computation to reduce memory input/output (I/O) operations.

This is achieved by dividing the input matrices ($Q$, $K$, $V$) and the output matrix ($O$) into smaller blocks.

The purpose of this division is to ensure that these smaller blocks can fit into the GPU's on-chip SRAM (Static Random Access Memory), which is much faster than the main high-bandwidth memory (HBM).

By operating on these smaller blocks, FlashAttention-2 minimises data transfers between HBM and on-chip SRAM, thereby reducing the memory I/O bottleneck and improving overall performance.

As the figure illustrates, tiling means splitting the matrices into smaller blocks and performing the computation on those blocks in SRAM, instead of moving a complete $N \times N$ matrix between HBM and SRAM. This is necessary because SRAM capacity is very limited: the entire matrix cannot fit into it.
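Tiling is easiest to see on a plain matrix product. The sketch below is a generic tiled matrix multiply, not the FlashAttention kernel: each output tile is computed while touching only one tile of $A$ and one tile of $B$ at a time, which is the access pattern that lets a GPU kernel stage data in SRAM. The tile size of 2 is an arbitrary assumption; real kernels choose block sizes to fit the available SRAM.

```python
def tiled_matmul(A, B, tile=2):
    """C = A @ B computed tile by tile over rows, columns and the shared dimension."""
    m, k, n = len(A), len(B), len(B[0])
    C = [[0 for _ in range(n)] for _ in range(m)]
    for i0 in range(0, m, tile):              # block of output rows
        for j0 in range(0, n, tile):          # block of output columns
            for p0 in range(0, k, tile):      # one tile of A and one tile of B at a time
                for i in range(i0, min(i0 + tile, m)):
                    for j in range(j0, min(j0 + tile, n)):
                        C[i][j] += sum(A[i][p] * B[p][j] for p in range(p0, min(p0 + tile, k)))
    return C


def naive_matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


A = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]]
B = [[1, 0, 2, 1], [0, 1, 1, 2], [3, 1, 0, 1], [1, 1, 1, 0], [2, 0, 1, 1]]
print(tiled_matmul(A, B) == naive_matmul(A, B))  # True
```

Attention adds one complication to this picture: softmax needs a whole row of scores, which is what the online softmax trick resolves.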

Online Softmax Computation

  • FlashAttention-2 uses the online softmax technique to enable block-wise computation of attention.

  • Instead of computing the softmax over an entire row of the attention scores matrix ($S$) at once, FlashAttention-2 processes the row one block of keys at a time, computing a local softmax for each block.

  • It maintains a running maximum and a running sum of exponentials for each row of scores, updating them as each new block arrives, which allows the attention probabilities to be computed correctly without materialising the entire $S$ and $P$ matrices.

  • Whenever the running maximum changes, the partial output accumulated so far is rescaled, so the final result is exactly the same as the standard attention mechanism, but with reduced memory requirements.
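For a single query row, the online softmax can be sketched in a few lines of Python. This is an illustration under simplifying assumptions (one row, nested lists, no $1/\sqrt{d}$ scaling, a hypothetical key-block size of 2), not the FlashAttention-2 kernel. It keeps a running max `m`, a running sum `l` and an unnormalised output accumulator, rescales them whenever the max grows, and divides by `l` only once at the end, which mirrors FlashAttention-2's choice to defer that final rescaling.

```python
import math


def naive_row_attention(q, K, V):
    """Reference: softmax over the whole score row, then a weighted sum of V."""
    s = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    e = [math.exp(x - m) for x in s]
    t = sum(e)
    return [sum(e[j] * V[j][c] for j in range(len(K))) / t for c in range(len(V[0]))]


def online_row_attention(q, K, V, block=2):
    """Same result, visiting K and V one block at a time without storing the score row."""
    m = float("-inf")          # running max of the scores seen so far
    l = 0.0                    # running sum of exp(score - m)
    acc = [0.0] * len(V[0])    # running unnormalised output
    for start in range(0, len(K), block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = [sum(a * b for a, b in zip(q, k)) for k in Kb]   # scores for this block only
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)                          # shrink old statistics to the new max
        p = [math.exp(x - m_new) for x in s]
        l = l * scale + sum(p)
        acc = [a * scale + sum(p[j] * Vb[j][c] for j in range(len(Kb)))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]  # normalise once, at the very end


q = [0.5, -1.0]
K = [[1.0, 2.0], [0.3, -0.7], [2.0, 0.1], [-1.5, 0.4], [0.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
print(naive_row_attention(q, K, V))
print(online_row_attention(q, K, V))
```

On the first block the running max is $-\infty$, so `scale` is `exp(-inf) = 0.0` and the empty initial statistics contribute nothing.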

Recomputation and Gradient Checkpointing

  • In the backward pass, FlashAttention-2 avoids storing the large intermediate matrices $S$ and $P$ by recomputing their values once the blocks of inputs ($Q$, $K$, $V$) are loaded into on-chip SRAM. To make this possible, the forward pass saves only one small statistic per row (the logsumexp of that row of scores).

  • By not having to save the $S$ and $P$ matrices, which have a size of $N \times N$ (where $N$ is the sequence length), FlashAttention-2 achieves a significant memory saving, reducing the memory requirement from quadratic to linear in sequence length.

  • The recomputation strategy, also known as gradient checkpointing, trades off some additional computation for reduced memory usage, enabling the processing of longer sequences with limited memory resources.
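The recomputation idea can be sketched as follows, assuming, as the FlashAttention-2 paper describes, that the forward pass saves one logsumexp value $L_i = m_i + \log \ell_i$ per query row. Given $L$, any block of $P$ can be rebuilt from the matching blocks of $Q$ and $K$ as $P_{ij} = \exp(S_{ij} - L_i)$, so only $N$ extra numbers are stored instead of $N^2$. For brevity the forward statistic is computed over full rows here rather than online, and the function names are hypothetical.

```python
import math


def scores(Q, K):
    return [[sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]


def forward_logsumexp(Q, K):
    """Per-row statistic kept from the forward pass: L_i = m_i + log(sum_j exp(S_ij - m_i))."""
    L = []
    for row in scores(Q, K):
        m = max(row)
        L.append(m + math.log(sum(math.exp(x - m) for x in row)))
    return L


def recompute_P_block(Q_blk, K_blk, L_blk):
    """Backward pass: rebuild one block of P from Q/K blocks and the saved L; no stored P needed."""
    return [[math.exp(s - L_i) for s in row] for row, L_i in zip(scores(Q_blk, K_blk), L_blk)]


Q = [[0.2, 1.0], [-0.4, 0.3], [1.1, -0.2], [0.0, 0.7]]
K = [[0.5, 0.1], [-0.3, 0.9], [0.8, -0.6], [0.2, 0.2]]
L = forward_logsumexp(Q, K)                        # N numbers instead of an N x N matrix
P_full = recompute_P_block(Q, K, L)
P_block = recompute_P_block(Q[2:], K[:2], L[2:])   # rows 2-3, key columns 0-1 only
print([round(sum(row), 6) for row in P_full])      # [1.0, 1.0, 1.0, 1.0]
```

The extra work is one more $QK^\top$ block product per tile in the backward pass, which is the compute-for-memory trade the bullet above describes.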

Improved Parallelism

  • FlashAttention-2 enhances parallelism by processing different blocks of the attention computation on different thread blocks of the GPU.

  • In the forward pass, FlashAttention-2 parallelises the computation across the sequence length dimension, in addition to the batch and head dimensions. Each thread block handles a block of rows of the attention matrix, allowing for efficient utilisation of GPU resources, especially when dealing with long sequences and small batch sizes.

  • In the backward pass, FlashAttention-2 parallelises the computation across the sequence length dimension as well, with each thread block handling a block of columns of the attention matrix. This parallelisation scheme improves the utilisation of GPU resources and reduces the overall runtime.

From the paper: In the forward pass (left), we parallelise the workers (thread blocks) where each worker takes care of a block of rows of the attention matrix. In the backward pass (right), each worker takes care of a block of columns of the attention matrix
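A rough count of independent work items shows why this matters. The sketch assumes one thread block per (sequence, head) pair for the original FlashAttention, and one per (sequence, head, block of 128 query rows) for FlashAttention-2's forward pass; the block size of 128 and the batch, head and sequence sizes in the example are hypothetical, while the 108 SMs come from the A100 figures given earlier.

```python
import math

A100_NUM_SMS = 108  # streaming multiprocessors on an A100, as stated earlier


def forward_thread_blocks(batch, heads, seq_len, block_rows=128, split_seq=True):
    """Independent thread blocks launched for the forward pass in this simplified model."""
    if not split_seq:
        return batch * heads                                # one block per (sequence, head)
    return batch * heads * math.ceil(seq_len / block_rows)  # also split the query rows


# Long sequence, small batch: 1 sequence of 8,192 tokens with 16 heads.
old = forward_thread_blocks(1, 16, 8192, split_seq=False)
new = forward_thread_blocks(1, 16, 8192)
print(old, old >= A100_NUM_SMS)  # 16 False
print(new, new >= A100_NUM_SMS)  # 1024 True
```

In this simplified model, 16 thread blocks would leave most of the 108 SMs idle, while splitting the sequence dimension gives the scheduler enough blocks to occupy the whole GPU.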

Optimised Work Partitioning

  • FlashAttention-2 optimises the work partitioning between different warps within each thread block to reduce communication and shared memory reads/writes.

  • In the forward pass, FlashAttention-2 splits the query matrix ($Q$) across different warps, while keeping the key ($K$) and value ($V$) matrices accessible by all warps. This approach eliminates the need for communication between warps and reduces shared memory reads/writes compared to the "split-K" scheme used in the original FlashAttention.

  • In the backward pass, FlashAttention-2 also avoids the "split-K" scheme and partitions the work between warps to minimise shared memory reads/writes, leading to improved performance.

What is the "Split K" fix

  • In the context of attention computation, "split-K" refers to a scheme where the key matrix ($K$) is split across different warps (groups of threads) within a thread block.

  • In the original FlashAttention implementation, the key matrix ($K$) and value matrix ($V$) are divided and assigned to different warps, while the query matrix ($Q$) is accessible by all warps.

  • Each warp multiplies $Q$ by its slice of $K$ to compute a slice of the attention scores ($QK^\top$), then multiplies that slice by its portion of $V$. Because each warp produces only a partial result for every output row, the warps must write their intermediate results to shared memory, synchronise, and add them up.

  • This communication and these shared memory reads/writes introduce overhead and slow down the computation.

  • FlashAttention-2 addresses this issue by using a different work partitioning scheme in the forward pass, where the query matrix $Q$ is split across warps instead of the key matrix $K$.

Work partitioning between different warps in the forward pass
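The difference between the two partitioning schemes can be sketched on the final product $O = PV$ alone. This simplification ignores the softmax statistics that a real split-K kernel must also combine, and each "warp" is just a loop iteration. Under split-K, each warp owns a slice of the key dimension, so its partial result touches every output row and the partials must be summed; under split-Q, each warp owns whole output rows and no cross-warp reduction is needed.

```python
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def split_k(P, V, num_warps):
    """Each warp takes a slice of the key dimension: columns of P and rows of V."""
    n_keys, d = len(V), len(V[0])
    chunk = -(-n_keys // num_warps)  # ceiling division
    partials = []
    for w in range(num_warps):
        keys = range(w * chunk, min((w + 1) * chunk, n_keys))
        if not keys:
            continue
        P_w = [[row[j] for j in keys] for row in P]
        V_w = [V[j] for j in keys]
        partials.append(matmul(P_w, V_w))  # covers ALL output rows, but only partially
    # Cross-warp reduction: in a kernel this means writing to shared memory,
    # synchronising, and adding the partial results together.
    return [[sum(part[i][c] for part in partials) for c in range(d)] for i in range(len(P))]


def split_q(P, V, num_warps):
    """Each warp takes a slice of query rows and reads all of V; rows are finished independently."""
    chunk = -(-len(P) // num_warps)
    out = []
    for w in range(num_warps):
        out.extend(matmul(P[w * chunk:(w + 1) * chunk], V))  # no communication between warps
    return out


P = [[1, 2, 0, 1], [0, 1, 3, 1], [2, 0, 1, 1], [1, 1, 1, 1]]
V = [[1, 2], [3, 4], [5, 6], [7, 8]]
print(split_k(P, V, 2) == matmul(P, V), split_q(P, V, 2) == matmul(P, V))  # True True
```

Both schemes produce the same output; the design difference is entirely in the final reduction step that split-K needs and split-Q avoids.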

Empirical Validation and Performance Gains

The FlashAttention-2 paper provides detailed benchmarks and performance comparisons that demonstrate the significant improvements achieved by the optimised attention mechanism.

Attention Computation Speed

FlashAttention-2 showcases remarkable speed improvements in attention computation compared to previous implementations.

The paper presents comprehensive benchmarks on NVIDIA A100 GPUs, comparing FlashAttention-2 with PyTorch's standard attention implementation, the original FlashAttention, and FlashAttention implemented in the Triton language.

FlashAttention-2 achieves up to 3 times faster attention computation compared to the original FlashAttention and up to 10 times faster than PyTorch's standard attention implementation.

For example, on sequences of length 2048 with a head dimension of 128, FlashAttention-2 attains an impressive 187 TFLOPs/s in the forward pass, surpassing FlashAttention's 120 TFLOPs/s and PyTorch's 60 TFLOPs/s.

These speed improvements are consistent across various sequence lengths and head dimensions, showcasing FlashAttention-2's ability to efficiently utilise GPU resources and minimise memory bottlenecks.
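TFLOPs/s figures like these come from dividing a nominal FLOP count by measured runtime. The sketch below uses the counting convention that, as we understand it, the FlashAttention-2 benchmarks follow: $4 \cdot N^2 \cdot d$ FLOPs per head for the forward pass (two matrix multiplications at 2 FLOPs per multiply-add), halved with a causal mask, and multiplied by 2.5 for the backward pass. The 312 TFLOPs/s peak is the A100's dense FP16/BF16 Tensor Core figure; the runtime in the example is hypothetical, not a measured result.

```python
A100_PEAK_TFLOPS = 312.0  # dense FP16/BF16 Tensor Core peak of an A100


def attention_flops(batch, heads, seq_len, head_dim, causal=False, backward=False):
    """Nominal FLOPs: QK^T and PV each cost 2 * N^2 * d per head."""
    flops = 4 * batch * heads * seq_len ** 2 * head_dim
    if causal:
        flops //= 2             # about half of the score matrix is masked out
    if backward:
        flops = flops * 5 // 2  # backward counted as 2.5x the forward pass
    return flops


def tflops_per_second(flops, seconds):
    return flops / seconds / 1e12


flops = attention_flops(batch=8, heads=16, seq_len=2048, head_dim=128)
runtime_s = 1.5e-3  # hypothetical measured time, for illustration only
rate = tflops_per_second(flops, runtime_s)
print(flops, round(rate, 1), f"{rate / A100_PEAK_TFLOPS:.0%} of peak")
```

Dividing by the hardware peak is what turns a throughput number into the "percentage of theoretical maximum FLOPs/s" used earlier to compare FlashAttention with optimised GEMM.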

Backward Pass Efficiency

The backward pass in attention mechanisms is known to be computationally challenging due to the dependencies between gradients and the need to store intermediate values.

However, FlashAttention-2 tackles this challenge head-on and achieves significant speedups in the backward pass as well. The paper presents detailed benchmarks comparing the backward pass performance of FlashAttention-2 with other implementations.

On A100 GPUs, FlashAttention-2 achieves a backward pass up to 2 times faster than the original FlashAttention.

It also outperforms PyTorch's standard attention implementation by a significant margin.

For instance, on sequences of length 4096 with a head dimension of 128, FlashAttention-2 achieves 165 TFLOPs/s in the backward pass, while FlashAttention reaches 80 TFLOPs/s and PyTorch's implementation obtains only 45 TFLOPs/s.

This showcases FlashAttention-2's ability to efficiently compute gradients and optimise memory usage in the backward pass.

The implications of FlashAttention-2's performance improvements are far-reaching. It enables faster training and inference of Transformer models, opening up possibilities for handling longer sequences and more complex tasks.

The Impact of FlashAttention-2

FlashAttention-2 represented a significant advancement in the field of efficient Transformer models, particularly in addressing the limitations of the standard attention mechanism.

FlashAttention-2 achieves remarkable improvements in both speed and memory efficiency.

The empirical results demonstrate that FlashAttention-2 can significantly outperform previous implementations, including the original FlashAttention and standard attention implementations in PyTorch.

The implications of FlashAttention-2 extend beyond just raw performance gains. It opens up new possibilities for training and deploying large Transformer models capable of handling longer sequences and more complex tasks.

As the demand for AI models that can understand and generate longer contexts continues to grow, FlashAttention-2 provides a promising solution to address the challenges associated with scaling attention mechanisms.

Paper References

Ainslie et al. (2023) and Shazeer (2019) discuss techniques for making Transformer inference more efficient: grouped-query attention and multi-query attention (fast Transformer decoding), respectively.

Several papers, including Beltagy et al. (2020), Katharopoulos et al. (2020), Kitaev et al. (2020), Roy et al. (2021), Wang et al. (2020), and Zaheer et al. (2020), propose alternative attention mechanisms to address the quadratic complexity of standard attention. These techniques include sparse attention, linearized attention, and hierarchical attention.

Chen et al. (2021) and Choromanski et al. (2020) introduce novel attention mechanisms, such as Scatterbrain and Performers, to improve the efficiency of Transformers.

Dao et al. (2022) propose the original FlashAttention algorithm, which serves as the foundation for FlashAttention-2.

Jia and Van Sandt (2021) and Jia et al. (2018) provide insights into the architecture and performance characteristics of NVIDIA GPUs, which are relevant for optimising attention computations.

Lefaudeux et al. (2022) introduce the xformers library, which includes implementations of various efficient attention mechanisms.

Milakov and Gimelshein (2018) and Rabe and Staats (2021) discuss techniques for optimising the softmax computation in attention mechanisms.

OpenAI (2023) presents the technical details of the GPT-4 model, which demonstrates the importance of scaling Transformers to longer sequences.

Shoeybi et al. (2019) introduce Megatron-LM, a large-scale language model trained using model parallelism, highlighting the need for efficient attention mechanisms in training billion-parameter models.

Tillet et al. (2019) propose Triton, an intermediate language and compiler for tiled neural network computations, which is used in some implementations of FlashAttention.

Vaswani et al. (2017) introduce the original Transformer architecture and the standard attention mechanism, which serves as the basis for subsequent research on efficient attention.
