# Flash Attention 2

FlashAttention-2, the sequel to the original FlashAttention, was introduced in a seminal July 2023 paper and demonstrates a process that can significantly boost the efficiency of Transformer models.

With Transformers at the heart of generative AI models for text, code, audio, and video, their **scalability is paramount**.

However, the **attention mechanism**, a critical component of Transformers, has been a bottleneck due to its *quadratic memory and runtime increase with sequence length*.

Therefore, the main goal is to address the quadratic runtime and memory requirements of the attention layer, which hinder the scalability of Transformers to longer sequences.

FlashAttention-2, developed by Tri Dao of the Departments of Computer Science at Princeton University and Stanford University, tackled this issue head-on.


### The Leap from FlashAttention to FlashAttention-2

FlashAttention made waves by *reordering the computation of attention to leverage the GPU memory hierarchy more effectively*, significantly speeding up the process and reducing memory usage from quadratic to linear in sequence length.

This yielded a 2-4 times speedup over optimised baselines and up to 10-20 times memory savings, with no approximation.

Despite its success, FlashAttention still fell short of reaching the efficiency of optimised **General Matrix Multiply (GEMM)** operations, achieving only 25-40% of the theoretical maximum FLOPs/s.

### Some definitions

**Sequence Length:** The number of tokens in an input sequence. For example, in natural language processing tasks, the sequence length is the number of words or subwords in a given text input.

**Batch Dimension:** Data is often processed in batches to improve computational efficiency and convergence properties. The **batch dimension** refers to the number of samples or sequences that are processed together in a single forward or backward pass of the model.

Instead of processing one sequence at a time, multiple sequences are stacked together to form a batch. This allows for parallel computation and better utilisation of GPUs.
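As a quick illustration (the sizes $B = 3$, $N = 4$, $d = 8$ are arbitrary choices for this sketch, not values from the paper), here is a minimal NumPy example of stacking sequences along a batch dimension:

```python
import numpy as np

# Illustrative sizes only: three sequences, each of length N=4 with d=8 features.
rng = np.random.default_rng(0)
sequences = [rng.standard_normal((4, 8)) for _ in range(3)]

# Stack the sequences along a new leading batch dimension.
batch = np.stack(sequences)
print(batch.shape)  # (3, 4, 8) -> (batch, sequence length, feature dimension)
```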

**GEMM:** GEMM stands for **General Matrix Multiplication**, which is a fundamental operation in linear algebra and is used in attention mechanisms.

GEMM involves multiplying two **matrices** (often denoted as $A$ and $B$) and adding the result to a third matrix ($C$).

The standard formula for GEMM is:

$$C \leftarrow \alpha AB + \beta C$$

where $\alpha$ and $\beta$ are **scalar coefficients** (usually set to 1 and 0, respectively, for standard matrix multiplication).
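As a concrete illustration, here is a minimal NumPy sketch of the GEMM formula above; the matrix sizes are arbitrary, and with $\alpha = 1$, $\beta = 0$ the update reduces to a plain matrix multiply:

```python
import numpy as np

# GEMM: C <- alpha * (A @ B) + beta * C, with illustrative sizes.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3))
B = rng.standard_normal((3, 5))
C = np.zeros((4, 5))

alpha, beta = 1.0, 0.0          # the standard matrix-multiplication case
C = alpha * (A @ B) + beta * C  # with beta = 0 this is simply A @ B
print(C.shape)                  # (4, 5)
```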

### GEMM and the Attention Mechanism

In the context of attention mechanisms, GEMM is used to compute the **attention scores** ($QK^⊤$) and the **attention output** ($PV$), where:

- $Q$ is the **query matrix**
- $K$ is the **key matrix**
- $V$ is the **value matrix**
- $P$ is the **attention probability matrix** (obtained by **applying softmax** to the **attention scores**)

#### The dimensions of these matrices are:

- $Q, K, V ∈ ℝ^{N×d}$, with **dimensions** $(N, d)$, where $N$ is the **sequence length** and $d$ is the **head dimension**
- $P ∈ ℝ^{N×N}$, with **dimensions** $(N, N)$
- **Output** $O ∈ ℝ^{N×d}$, with **dimensions** $(N, d)$

**In this notation:**

- The **symbol** $∈$ is the mathematical notation for "belongs to" or "is an element of." In this context, it indicates that a matrix belongs to a space of a specific dimension.
- $ℝ$ represents the **set of real numbers**
- $(N, d)$ and $(N, N)$ represent the **dimensions of the matrices**, with $N$ being the **number of rows** and $d$ or $N$ being the **number of columns**, depending on the matrix

### The shortfalls of the standard attention mechanism

The **standard attention mechanism **is a fundamental component of Transformer models.

However, as highlighted, it suffers from a significant limitation: its *inability to scale* efficiently as the **sequence length** (the number of tokens or words in the input) increases.

**Attention Mechanism - how does it work?**

Given an **input sequence of length** $N$ and a **head dimension** $d$, the goal of the **attention mechanism** is to compute the **attention output** $O$, which is an $N × d$ **matrix**.

The **attention mechanism** then operates as follows:

Compute the **attention scores** $S$:

$S = QK^⊤$ where $Q$ and $K$ are the **query** and **key** matrices, respectively, both of size $N × d$.

The **attention score matrix** $S$ is an $N × N$ **matrix** that represents the similarity or relevance between each pair of tokens in the **input sequence**.

Apply the **softmax function** to the **attention scores** $S$ row-wise:

$P = softmax(S)$

The **softmax function** normalises the **attention scores**, converting them into **attention probabilities** $P$, which is *also* an $N × N$ **matrix**.

Each **row** of $P$ sums up to **1**, representing the **distribution of attention weights** for a specific token.

Compute the **attention output** $O$:

$O = PV$ where $V$ is the **value matrix** of size $N × d$.

The **attention output** $O$ is an $N × d$ **matrix** that represents the **weighted sum of the value vectors**, with the weights determined by the **attention probabilities** $P$.
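To make these three steps concrete, here is a minimal NumPy sketch of the reference (non-flash) attention forward pass. The sizes $N = 8$, $d = 4$ are arbitrary, and the $1/\sqrt{d}$ scaling used in practice is omitted to match the formulas above:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference attention forward pass, following the three steps above.
    Q, K, V have shape (N, d); the result O has shape (N, d)."""
    S = Q @ K.T                              # step 1: (N, N) attention scores
    S = S - S.max(axis=-1, keepdims=True)    # for numerical stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)    # step 2: row-wise softmax, rows sum to 1
    return P @ V                             # step 3: (N, d) weighted sum of values

# Tiny usage example with assumed sizes N=8, d=4.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
O = standard_attention(Q, K, V)
print(O.shape)  # (8, 4)
```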

**Forward and Backward Pass**

In the **forward pass**, the **input sequences are processed in parallel** across the batch dimension - multiple sequences are processed simultaneously - and the **attention computation** is performed independently for each head.

This parallelisation enables efficient utilisation of hardware resources and faster computation.

The backward pass, which is crucial for training the model, involves **computing the gradients with respect to the inputs**, given the **gradient of the output**.

Let $dO$ be the **gradient of some loss function** with respect to the **attention output** $O$.

By applying the **chain rule** (backpropagation), we can compute the **gradients** as follows:

Compute the gradient of $V$: $dV = P^⊤ dO$

Compute the gradient of $P$: $dP = dO V^⊤$

Compute the gradient of $S$ using the gradient of the softmax function: $dS = dsoftmax(dP)$

Compute the gradient of $Q$: $dQ = dS K$

Compute the gradient of $K$: $dK = dS^⊤ Q$

The $dsoftmax$ function denotes the **gradient** of the **softmax function** applied row-wise.

It can be shown that if $p = softmax(s)$ for some **vectors** $s$ and $p$, then with the **output gradient** $dp$, the **input gradient** $ds$ is given by:

$ds = (diag(p) - pp^⊤) dp$

where $diag(p)$ is a **diagonal matrix** with the **elements** of $p$ on its diagonal.

This formula captures the relationship between the input and output **gradients **of the **softmax function**, considering the interdependence of the output probabilities.
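The gradient formulas above translate directly into a reference (non-flash) backward pass. The sketch below materialises $S$ and $P$, so it has the quadratic memory cost discussed next; note that the row-wise softmax gradient $ds = (diag(p) - pp^⊤)dp$ simplifies to $p$ times ($dp$ minus the row-wise sum of $p ⊙ dp$):

```python
import numpy as np

def softmax(x):
    # Row-wise, numerically stable softmax.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention_backward(Q, K, V, dO):
    """Reference attention backward pass, mirroring the gradient formulas
    above. It materialises S and P, so it uses O(N^2) memory."""
    S = Q @ K.T                     # (N, N) attention scores
    P = softmax(S)                  # (N, N) attention probabilities
    dV = P.T @ dO                   # dV = P^T dO
    dP = dO @ V.T                   # dP = dO V^T
    # Row-wise softmax gradient: ds = (diag(p) - p p^T) dp,
    # which simplifies to p * (dp - sum(p * dp)) per row.
    dS = P * (dP - (P * dP).sum(axis=-1, keepdims=True))
    dQ = dS @ K                     # dQ = dS K
    dK = dS.T @ Q                   # dK = dS^T Q
    return dQ, dK, dV
```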

**Memory Limitations**

In standard attention implementations, the **attention scores matrix** $S$ and the **attention probabilities matrix** $P$ are *materialised* and **stored in the high-bandwidth memory (HBM)** of the GPU.

The memory requirement for storing these matrices is **proportional to** $N^2$, where $N$ is the **sequence length**.

This **quadratic memory requirement** becomes a significant limitation when dealing with long sequences.

In typical applications, the **sequence length** $N$ can **range from 1,000 to 8,000 tokens**, while the **head dimension** $d$ is usually **much smaller**, around 64-128.

As a result, the memory consumption grows rapidly with increasing **sequence length**, making it challenging to process long sequences efficiently.

To put this into perspective, let's consider an example where $N$ = 5,000 and $d$ = 64.

The **attention scores matrix** $S$ and the **attention probabilities matrix** $P$ would each require storing 5,000 × 5,000 = 25,000,000 **elements**.

Assuming each element is a **32-bit floating-point number**, each matrix occupies about 100 MB, so the total memory requirement for both matrices would be approximately **200 MB**.

This memory consumption can quickly become prohibitive, especially when dealing with larger sequence lengths or multiple attention heads.
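The arithmetic behind these figures is straightforward, as this small sketch shows:

```python
# Memory cost of materialising S and P for N = 5,000 with float32 elements.
N, bytes_per_element = 5_000, 4
elements_per_matrix = N * N                       # 25,000,000
mb_per_matrix = elements_per_matrix * bytes_per_element / 1e6
print(f"{mb_per_matrix:.0f} MB per matrix")       # ~100 MB
print(f"{2 * mb_per_matrix:.0f} MB for S and P")  # ~200 MB
```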

**Conclusion**

The **standard attention mechanism**, while powerful, *suffers from a quadratic memory requirement* that limits its scalability to long sequences.

This limitation arises from the need to materialise and store the **attention score** and probability matrices, which *grow quadratically* with the **sequence length**.

To address this challenge, various alternative **attention mechanisms** have been proposed, such as sparse attention, linearised attention, and hierarchical attention.

These mechanisms have aimed to reduce the memory footprint and computational complexity of the attention mechanism while still capturing the essential dependencies in the input sequence.

To date, FlashAttention and FlashAttention-2 are the most effective solutions to this quadratic problem.

### Flash Attention 2

**FlashAttention-2** is an **improved version** of the **FlashAttention** algorithm. It aims to further optimise the **attention mechanism** for longer **sequences** by reducing memory usage and improving computational efficiency.

FlashAttention-2 addresses the quadratic memory requirement issue of the standard **attention mechanism** by employing several techniques to reduce memory reads/writes and improve parallelism.

**Tiling and Block-wise Computation**

FlashAttention-2 employs a technique called **tiling and block-wise computation** to reduce memory input/output (I/O) operations.

This is achieved by dividing the **input matrices** $(Q, K, V)$ and **output matrix** $(O)$ into smaller blocks.

The purpose of this division is to ensure that these smaller blocks can fit into the GPU's on-chip **SRAM (Static Random Access Memory)**, which is much faster than the main **high-bandwidth memory (HBM)**.

By operating on these smaller blocks, FlashAttention-2 minimises the need for frequent data transfers between HBM and on-chip SRAM, thereby reducing the memory I/O bottleneck and improving overall performance.
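To see why block sizes matter, here is a hypothetical back-of-the-envelope check that a set of blocks fits in on-chip SRAM. The 100 KB budget, fp16 element size, and block dimensions are illustrative assumptions, not figures from the paper:

```python
# Hypothetical sketch: checking that blocks of Q, K, V and the score block
# fit in an assumed on-chip SRAM budget. All numbers are illustrative.
def fits_in_sram(Br, Bc, d, elem_bytes=2, sram_budget=100 * 1024):
    q_block = Br * d * elem_bytes        # one (Br, d) block of Q
    kv_block = 2 * Bc * d * elem_bytes   # one (Bc, d) block each of K and V
    s_block = Br * Bc * elem_bytes       # the (Br, Bc) block of scores
    return q_block + kv_block + s_block <= sram_budget

print(fits_in_sram(Br=128, Bc=128, d=64))  # True under these assumptions
```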

**Online Softmax Computation**

FlashAttention-2 uses the **online softmax technique** to enable block-wise computation of attention.

Instead of computing the **softmax function** over the entire **attention scores matrix** $(S)$, FlashAttention-2 **computes the softmax function for each block independently**. It maintains a running maximum and a running sum of exponentials for each block, which allows for the correct computation of **attention probabilities** without materialising the entire $S$ and $P$ matrices.

The **online softmax technique** ensures that the output of each block is correctly scaled, resulting in the same output as the standard **attention mechanism**, but with reduced memory requirements.
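Below is a minimal NumPy sketch of the online softmax idea: it processes $K$ and $V$ in blocks, maintains a running row-wise maximum $m$ and a running sum of exponentials $l$, and rescales the partial output whenever the maximum changes, so the full $N × N$ matrices $S$ and $P$ are never materialised. It is a didactic illustration, not the fused CUDA kernel FlashAttention-2 actually uses:

```python
import numpy as np

def online_softmax_attention(Q, K, V, block=2):
    """Block-wise attention with online softmax. Only (N, block) score
    slices are ever formed, never the full (N, N) matrices S and P."""
    N, d = Q.shape
    O = np.zeros((N, d))
    m = np.full((N, 1), -np.inf)   # running row-wise maximum
    l = np.zeros((N, 1))           # running sum of exponentials
    for j in range(0, N, block):
        S = Q @ K[j:j + block].T                   # scores for this block only
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)                  # rescale previous partials
        p = np.exp(S - m_new)                      # block-local unnormalised probs
        l = l * scale + p.sum(axis=-1, keepdims=True)
        O = O * scale + p @ V[j:j + block]
        m = m_new
    return O / l                                   # final normalisation
```

For any block size, this produces the same output as the standard attention computation, up to floating-point rounding.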

**Recomputation and Gradient Checkpointing**

In the **backward pass**, FlashAttention-2 **avoids storing the large intermediate matrices** $S$ and $P$ by recomputing their values once the blocks of **inputs** $(Q, K, V)$ are loaded into on-chip SRAM.

By not having to save the $S$ and $P$ matrices, which have a size of $N × N$ (where $N$ is the **sequence length**), FlashAttention-2 achieves a significant memory saving, **reducing the memory requirement from quadratic to linear in sequence length**.

The recomputation strategy, also known as **gradient checkpointing**, trades off some additional computation for reduced memory usage, enabling the processing of longer **sequences** with limited memory resources.
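The same memory-for-compute trade-off can be demonstrated with PyTorch's generic checkpointing utility. This is only an analogy: FlashAttention-2 fuses its recomputation into the attention kernel itself rather than using this module:

```python
import torch
from torch.utils.checkpoint import checkpoint

def attention(Q, K, V):
    # Reference attention; autograd would normally keep S and P for backward.
    P = torch.softmax(Q @ K.transpose(-1, -2), dim=-1)
    return P @ V

# Illustrative sizes: N=1024, d=64.
Q, K, V = (torch.randn(1024, 64, requires_grad=True) for _ in range(3))

# With checkpointing, the intermediates are discarded after the forward pass
# and recomputed during backward, trading extra compute for less memory.
O = checkpoint(attention, Q, K, V, use_reentrant=False)
O.sum().backward()
```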

**Improved Parallelism**

FlashAttention-2 **enhances parallelism** by processing different blocks of the attention computation on **different thread blocks** of the GPU.

In the **forward pass**, FlashAttention-2 parallelises the computation across the **sequence length** dimension, **in addition to the batch and head dimensions**. Each thread block handles a block of rows of the **attention matrix**, allowing for efficient utilisation of GPU resources, especially when dealing with long **sequences** and small **batch sizes**.

In the **backward pass**, FlashAttention-2 **parallelises the computation across the sequence length dimension as well**, with each thread block handling a block of columns of the **attention matrix**. This parallelisation scheme improves the utilisation of GPU resources and reduces the overall runtime.

**Optimised Work Partitioning**

FlashAttention-2 optimises the work partitioning between different warps within each thread block to reduce communication and shared memory reads/writes.

In the **forward pass**, FlashAttention-2 splits the **query matrix** $(Q)$ across different warps, while keeping the **key** $(K)$ and **value** $(V)$ matrices accessible by all warps. This approach eliminates the need for communication between warps and reduces shared memory reads/writes compared to the "split-K" scheme used in the original FlashAttention.

In the **backward pass**, FlashAttention-2 also avoids the "split-K" scheme and partitions the work between warps to minimise shared memory reads/writes, leading to improved performance.

**What is the "split-K" fix?**

In the context of **attention computation**, "split-K" refers to a scheme where the **key matrix** $(K)$ is **split across different warps** (groups of threads) within a thread block.

In the original FlashAttention implementation, the **key matrix** $(K)$ and **value matrix** $(V)$ are divided and assigned to different warps, while the **query matrix** $(Q)$ is accessible by all warps.

After each warp performs a matrix multiplication between $Q$ and its assigned portion of $K$ to compute a slice of the **attention scores** $(QK^⊤)$, the warps need to communicate and exchange intermediate results through shared memory. This communication and these shared memory reads/writes introduce overhead and can slow down the computation.

**FlashAttention-2 addresses this issue** by using a **different work partitioning scheme** in the forward pass, where the **query matrix** $Q$ is split across warps instead of the **key matrix** $K$.

### Empirical Validation and Performance Gains

The FlashAttention-2 paper provides detailed benchmarks and performance comparisons that demonstrate the significant improvements achieved by the optimised **attention mechanism**.

#### Attention Computation Speed

FlashAttention-2 showcases remarkable speed improvements in attention computation compared to previous implementations.

The paper presents comprehensive benchmarks on **NVIDIA A100 GPUs**, comparing FlashAttention-2 with PyTorch's standard attention implementation, the original FlashAttention, and FlashAttention implemented in the Triton language.

FlashAttention-2 achieves up to **3 times faster attention computation** compared to the original FlashAttention and up to *10 times faster than PyTorch's standard attention implementation*.

For example, on **sequences of length 2048** with a **head dimension of 128**, FlashAttention-2 attains an impressive **187 TFLOPs/s** in the **forward pass**, surpassing FlashAttention's **120 TFLOPs/s** and PyTorch's **60 TFLOPs/s**.

These speed improvements are consistent across various **sequence lengths** and **head dimensions**, showcasing FlashAttention-2's ability to efficiently utilise GPU resources and minimise memory bottlenecks.

#### Backward Pass Efficiency

The **backward pass** in **attention mechanisms** is known to be computationally challenging due to the dependencies between **gradients** and the need to store intermediate values.

However, FlashAttention-2 tackles this challenge head-on and achieves significant speedups in the **backward pass** as well. The paper presents detailed benchmarks comparing the **backward pass** performance of FlashAttention-2 with other implementations.

On **A100 GPUs**, FlashAttention-2 demonstrates up to **2 times faster backward pass** compared to the original FlashAttention.

It also outperforms PyTorch's standard attention implementation by a significant margin.

For instance, on **sequences of length** **4096** with a **head dimension** of **128**, FlashAttention-2 achieves **165 TFLOPs/s** in the **backward pass**, while FlashAttention reaches **80 TFLOPs/s** and PyTorch's implementation obtains only **45 TFLOPs/s**.

This showcases FlashAttention-2's ability to *efficiently compute gradients and optimise memory usage* in the **backward pass**.

The implications of FlashAttention-2's performance improvements are far-reaching. It enables faster training and inference of Transformer models, opening up possibilities for handling longer **sequences** and more complex tasks.

### The Impact of FlashAttention-2

FlashAttention-2 represents a significant advancement in the field of efficient Transformer models, particularly in addressing the limitations of the standard **attention mechanism**.

FlashAttention-2 achieves remarkable improvements in both speed and memory efficiency.

The empirical results demonstrate that FlashAttention-2 can significantly outperform previous implementations, including the original FlashAttention and standard attention implementations in PyTorch.

The implications of FlashAttention-2 extend beyond just raw performance gains. It opens up new possibilities for training and deploying large Transformer models capable of handling longer **sequences** and more complex tasks.

As the demand for AI models that can understand and generate longer contexts continues to grow, FlashAttention-2 provides a promising solution to address the challenges associated with scaling attention mechanisms.

### Paper References

Ainslie et al. (2023) and Shazeer (2019) discuss techniques for training and optimising Transformer models, such as multi-query attention and fast transformer decoding.

Several papers, including Beltagy et al. (2020), Katharopoulos et al. (2020), Kitaev et al. (2020), Roy et al. (2021), Wang et al. (2020), and Zaheer et al. (2020), propose alternative attention mechanisms to address the quadratic complexity of standard attention. These techniques include sparse attention, linearised attention, and hierarchical attention.

Chen et al. (2021) and Choromanski et al. (2020) introduce novel attention mechanisms, such as Scatterbrain and Performers, to improve the efficiency of Transformers.

Dao et al. (2022) propose the original FlashAttention algorithm, which serves as the foundation for FlashAttention-2.

Jia and Van Sandt (2021) and Jia et al. (2018) provide insights into the architecture and performance characteristics of NVIDIA GPUs, which are relevant for optimising attention computations.

Lefaudeux et al. (2022) introduce the xformers library, which includes implementations of various efficient attention mechanisms.

Milakov and Gimelshein (2018) and Rabe and Staats (2021) discuss techniques for optimising the softmax computation in attention mechanisms.

OpenAI (2023) presents the technical details of the GPT-4 model, which demonstrates the importance of scaling Transformers to longer sequences.

Shoeybi et al. (2019) introduce Megatron-LM, a large-scale language model trained using model parallelism, highlighting the need for efficient attention mechanisms in training billion-parameter models.

Tillet et al. (2019) propose Triton, an intermediate language and compiler for tiled neural network computations, which is used in some implementations of FlashAttention.

Vaswani et al. (2017) introduce the original Transformer architecture and the standard attention mechanism, which serves as the basis for subsequent research on efficient attention.
