# Flash Attention 2

FlashAttention-2, the sequel to the original FlashAttention, significantly boosts the efficiency of Transformer models.

Transformers are at the heart of generative AI models for text, code, audio, and video, so their **scalability is paramount.**

However, the <mark style="color:blue;">**attention mechanism**</mark>, a critical component of Transformers, has been a bottleneck due to its *<mark style="color:yellow;">**quadratic memory and runtime increase with sequence length**</mark>*.&#x20;

The main goal, therefore, is to address the quadratic runtime and memory requirements of the attention layer, which hinder the scalability of Transformers to longer sequences.

FlashAttention-2, developed by Tri Dao of the Computer Science departments at Princeton University and Stanford University, tackles this issue head-on.

{% embed url="https://arxiv.org/abs/2307.08691" %}
FlashAttention-2, developed by Tri Dao from the departments of Computer Science at Princeton University and Stanford University
{% endembed %}

### <mark style="color:purple;">Interview with the inventor</mark>

{% embed url="https://www.youtube.com/watch?v=J4-qZ6KBalk" %}
Great interview with the inventor of Flash Attention 1 and 2
{% endembed %}

<details>

<summary><mark style="color:green;"><strong>A summary of the video transcript with Tri Dao</strong></mark></summary>

Tri Dao, the inventor of Flash Attention and Flash Attention 2, shares his insights on the development of these techniques, the current state of AI research, and his perspectives on the future of the field. Let's break down the key points from the discussion.

<mark style="color:green;">**Insights and Ideas behind Flash Attention**</mark>

* Dao and his team focused on making attention more memory-efficient without approximations, as most practitioners preferred exact methods over approximate ones.
* They realised that the majority of time in attention computation was spent on memory reading and writing, <mark style="color:yellow;">**not floating-point operations (FLOPs)**</mark>.
* By combining ideas from the systems side (kernel fusion, tiling) and the machine learning side (the online softmax trick), they developed Flash Attention, which *<mark style="color:yellow;">**reduces memory usage from quadratic to linear in sequence length**</mark>*.
* Flash Attention provides a 2-4x speedup compared to standard attention implementation, allowing for training with longer sequences at the same cost.

<mark style="color:green;">**Development of Flash Attention 2**</mark>

* Dao started working on Flash Attention 2 as an excuse to play with <mark style="color:blue;">**NVIDIA's CUTLASS library**</mark>, which provides primitives for efficient GPU operations.
* He realised there were still inefficiencies in Flash Attention and made tweaks to reduce non-matmul FLOPs, as modern GPUs have specialised units (e.g., Tensor Cores) that make matmul much faster than non-matmul operations.
* Flash Attention 2 also parallelises the computation across different thread blocks and partitions the work between warps to reduce shared memory access.
* These optimisations result in a 2x speedup compared to Flash Attention, approaching the efficiency of highly optimised matrix multiplication routines.

<mark style="color:green;">**Explanation of GPU Performance Characteristics**</mark>

* GPUs have a memory hierarchy consisting of <mark style="color:blue;">**high-bandwidth memory (HBM)**</mark> and <mark style="color:blue;">**on-chip SRAM (shared memory)**</mark>.
* HBM is larger but slower, while SRAM is much smaller but faster.
* The goal is to design algorithms that take advantage of this asymmetric memory hierarchy, loading data from HBM to SRAM, performing computations, and then writing results back to HBM.
* As sequence lengths grow, the memory required for attention matrices increases quadratically, while the number of attention heads remains relatively small, leading to a memory inefficiency that Flash Attention aims to address.

<mark style="color:green;">**Role of Academia and Industry in AI Research**</mark>

* Academia and industry play complementary roles in AI research.
* Industry has an advantage in scaling models and providing infrastructure, while academia focuses on developing new ideas, understanding and explaining models, and taking riskier bets on novel architectures.
* Dao believes it's important for researchers to maintain a balance between exploring fundamentals and keeping up with the latest developments in the field.

<mark style="color:green;">**Transformer Alternatives and Future Directions**</mark>

* Dao is interested in *<mark style="color:yellow;">**exploring transformer alternatives**</mark>*, such as recurrent neural networks (RNNs) and <mark style="color:blue;">**state space methods**</mark>, which could be more efficient for long sequences and high-throughput generation.
* He believes these alternatives could have advantages in specific use cases, such as processing high-resolution images, audio, and video.
* However, he also acknowledges the "hardware lottery" and the challenges of developing new architectures that can compete with the heavily-optimised transformer on current hardware and software frameworks.

</details>

### <mark style="color:purple;">The Leap from FlashAttention to FlashAttention-2</mark>

FlashAttention made waves by *<mark style="color:yellow;">**reordering the computation of attention to leverage the GPU memory hierarchy more effectively,**</mark>* significantly speeding up the process and reducing memory usage from quadratic to linear in sequence length.&#x20;

This yielded a 2-4 times speedup over optimised baselines and up to 10-20 times memory savings, with no approximation. &#x20;

Despite its success, FlashAttention still fell short of reaching the efficiency of optimised <mark style="color:blue;">**General Matrix Multiply (GEMM)**</mark> operations, achieving only 25-40% of the theoretical maximum FLOPs/s.

<details>

<summary><mark style="color:purple;"><strong>Refresher:</strong></mark><strong> </strong><mark style="color:green;"><strong>GPU Performance Characteristics</strong></mark></summary>

<mark style="color:purple;">**Compute elements**</mark>

GPUs consist of compute elements, such as floating-point arithmetic units, which perform the actual computations. These units are designed to handle parallel processing tasks efficiently.

<mark style="color:purple;">**Specialised matrix multiply units**</mark>

Modern GPUs, like Nvidia's Tensor Cores, have specialised units that accelerate matrix multiplication operations in low-precision formats, such as FP16 (half-precision floating-point) or BF16 (bfloat16).&#x20;

These units are optimised for fast and efficient matrix multiplication, which is a common operation in deep learning workloads.

<mark style="color:purple;">**Memory hierarchy**</mark>

GPUs have a memory hierarchy that consists of two main components:

<mark style="color:blue;">**High Bandwidth Memory (HBM):**</mark> HBM is a type of high-performance memory that offers high bandwidth and capacity. For example, the A100 GPU has 40-80GB of HBM with a bandwidth of 1.5-2.0TB/s. HBM is used to store large datasets and intermediate results during computation.

<mark style="color:blue;">**On-chip SRAM (Shared Memory):**</mark> GPUs also have on-chip SRAM, known as shared memory, which is a fast but limited-capacity memory. In the A100 GPU, each of the 108 streaming multiprocessors (SMs) has 192KB of shared memory with an estimated bandwidth of around 19TB/s. Shared memory is used for fast communication and data sharing between threads within a thread block.

<mark style="color:blue;">**L2 cache:**</mark> GPUs also have an L2 cache, which is a larger but slower cache compared to shared memory. However, the L2 cache is not directly controllable by the programmer and is managed by the GPU's cache hierarchy.

<mark style="color:purple;">**Execution Model**</mark>

<mark style="color:blue;">**Threads:**</mark> GPUs employ a massive number of threads to execute operations in parallel. These threads are lightweight and can be quickly switched to hide memory latency and maximise utilisation of the GPU's compute resources.

<mark style="color:blue;">**Thread blocks:**</mark> Threads are organized into thread blocks, which are scheduled to run on the GPU's streaming multiprocessors (SMs). Each thread block contains a set of threads that can cooperate and share resources, such as shared memory.

<mark style="color:blue;">**Warps:**</mark> Within each thread block, threads are further grouped into warps, typically consisting of <mark style="color:yellow;">32 threads</mark>. Threads within a warp execute in lockstep, meaning they perform the same instruction on different data elements simultaneously (SIMD - Single Instruction, Multiple Data). Warps are the basic unit of execution on the GPU.

<mark style="color:green;">**Communication and synchronization**</mark>

<mark style="color:blue;">**Shuffle instructions:**</mark> Threads within a warp can communicate using fast shuffle instructions, which allow them to exchange data directly without going through shared memory. This is useful for efficient data sharing and reduction operations within a warp.&#x20;

<mark style="color:blue;">**Shared memory:**</mark> Warps within a thread block can communicate by reading from and writing to shared memory. Shared memory acts as a fast, software-managed cache that enables efficient data sharing and coordination between warps.

<mark style="color:blue;">**Kernel execution:**</mark> A GPU kernel is a function that is executed in parallel by multiple threads on the GPU. Each kernel loads inputs from HBM into registers and shared memory, performs computations, and then writes the outputs back to HBM. The kernel is launched with a specified grid and block size, which determines the number of thread blocks and threads per block.

Understanding these GPU performance characteristics and the execution model is essential for optimising GPU code and achieving high performance.&#x20;

By leveraging the specialised matrix multiply units, using the memory hierarchy effectively, and organising threads and thread blocks appropriately, developers can maximise the utilisation of GPU resources and minimise memory latency, ultimately leading to <mark style="color:yellow;">**faster execution of parallel workloads.**</mark>

</details>

### <mark style="color:purple;">Some definitions</mark>

<mark style="color:blue;">**Sequence length:**</mark> The number of tokens in an input sequence. For example, in natural language processing tasks, the sequence length is the number of words or subwords in a given text input.&#x20;

<mark style="color:blue;">**Batch Dimension:**</mark> Data is often processed in batches to improve computational efficiency and convergence properties.  The <mark style="color:blue;">**batch dimension**</mark> refers to the number of samples or sequences that are processed together in a single forward or backward pass of the model.&#x20;

Instead of processing one sequence at a time, multiple sequences are stacked together to form a batch.  This allows for parallel computation and better utilisation of GPUs.&#x20;

<mark style="color:blue;">**GEMM:**</mark> GEMM stands for <mark style="color:blue;">**General Matrix Multiplication**</mark>, which is a fundamental operation in linear algebra and is used in attention mechanisms.&#x20;

GEMM involves multiplying two <mark style="color:blue;">**matrices**</mark> (often denoted as $$A$$ and $$B$$) and adding the result to a third matrix ($$C$$).

The standard formula for GEMM is:

$$
C = \alpha AB + \beta C
$$

where $$\alpha$$  and $$\beta$$ are <mark style="color:blue;">**scalar coefficients**</mark> (usually set to 1 and 0, respectively, for standard matrix multiplication).
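
As a concrete illustration, here is the GEMM formula written out in plain PyTorch. This is only a sketch of the arithmetic; production GEMM kernels (e.g. cuBLAS or CUTLASS) implement the same operation with tiling and fused accumulation.

```python
import torch

# A minimal GEMM sketch: C = alpha * A @ B + beta * C.
N, d = 4, 3
A = torch.randn(N, d)
B = torch.randn(d, N)
C = torch.zeros(N, N)

alpha, beta = 1.0, 0.0   # the defaults for plain matrix multiplication
C = alpha * (A @ B) + beta * C
```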

### <mark style="color:purple;">GEMM and the Attention Mechanism</mark>

In the context of attention mechanisms, GEMM is used to compute the <mark style="color:blue;">**attention scores**</mark> ($$QK^⊤$$) and the <mark style="color:blue;">**attention output**</mark> ($$PV$$), where:

* $$Q$$ is the <mark style="color:blue;">**query matrix**</mark>
* $$K$$ is the <mark style="color:blue;">**key matrix**</mark>
* $$V$$ is the <mark style="color:blue;">**value matrix**</mark>
* $$P$$ is the <mark style="color:blue;">**attention probability matrix**</mark> (obtained by <mark style="color:yellow;">**applying softmax**</mark> to the <mark style="color:blue;">**attention scores**</mark>)

#### <mark style="color:green;">The</mark> <mark style="color:green;"></mark><mark style="color:green;">**dimensions**</mark> <mark style="color:green;"></mark><mark style="color:green;">of these matrices are:</mark>

* $$Q, K, V \in \mathbb{R}^{N \times d}$$, where $$N$$ is the <mark style="color:blue;">**sequence length**</mark> and $$d$$ is the <mark style="color:blue;">**head dimension**</mark>
* $$P \in \mathbb{R}^{N \times N}$$
* <mark style="color:blue;">**Output**</mark> $$O \in \mathbb{R}^{N \times d}$$

<mark style="color:green;">**In this notation:**</mark>

1. The <mark style="color:yellow;">**symbol**</mark> "∈" is the mathematical notation for "belongs to" or "is an element of." In this context, it indicates that a matrix belongs to a particular set of matrices.
2. $$\mathbb{R}^{N \times d}$$ represents the <mark style="color:yellow;">**set of real-valued matrices**</mark> with $$N$$ <mark style="color:blue;">**rows**</mark> and $$d$$ <mark style="color:blue;">**columns**</mark>; likewise, $$\mathbb{R}^{N \times N}$$ denotes square matrices with $$N$$ rows and $$N$$ columns

<figure><img src="https://1839612753-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FpV8SlQaC976K9PPsjApL%2Fuploads%2FDQP8hHZs6a33ecDrE334%2Fimage.png?alt=media&#x26;token=0852c226-cb8c-4506-a026-735229686421" alt="" width="563"><figcaption></figcaption></figure>

### <mark style="color:purple;">The shortfalls of the standard attention mechanism</mark>

The <mark style="color:blue;">**standard attention mechanism**</mark> is a fundamental component of Transformer models.

However, as highlighted, it suffers from a significant limitation: its *<mark style="color:yellow;">**inability to scale**</mark>* efficiently as the <mark style="color:blue;">**sequence length**</mark> (the number of tokens or words in the input) increases.&#x20;

#### <mark style="color:green;">**Attention Mechanism - how does it work?**</mark>

Given an <mark style="color:blue;">**input sequence of length**</mark> $$N$$ and a <mark style="color:blue;">**head dimension**</mark> $$d$$, the goal of the <mark style="color:blue;">**attention mechanism**</mark> is to compute the <mark style="color:blue;">**attention output**</mark> $$O$$, which is an $$N × d$$ <mark style="color:blue;">**matrix**</mark>.&#x20;

The <mark style="color:blue;">**attention mechanism**</mark> then operates as follows:

Compute the <mark style="color:blue;">**attention scores**</mark> S

$$S = QK^⊤$$ where $$Q$$ and $$K$$ are the <mark style="color:blue;">**query**</mark> and <mark style="color:blue;">**key**</mark> matrices, respectively, both of size $$N × d$$.

The <mark style="color:blue;">**attention scores**</mark> $$S$$ is an $$N × N$$ <mark style="color:blue;">**matrix**</mark> that represents the similarity or relevance between each pair of tokens in the <mark style="color:blue;">**input sequence**</mark>.

Apply the <mark style="color:blue;">**softmax function**</mark> to the <mark style="color:blue;">**attention scores**</mark> S row-wise

$$P = softmax(S)$$&#x20;

The <mark style="color:blue;">**softmax function**</mark> normalises the <mark style="color:blue;">**attention scores**</mark>, converting them into <mark style="color:blue;">**attention probabilities**</mark> $$P$$, which is *<mark style="color:yellow;">**also**</mark>* an $$N × N$$ <mark style="color:blue;">**matrix**</mark>.&#x20;

Each <mark style="color:blue;">**row**</mark> of $$P$$ sums up to <mark style="color:yellow;">**1**</mark>, representing the <mark style="color:blue;">**distribution of attention weights**</mark> for a specific token.

Compute the <mark style="color:blue;">**attention output**</mark> $$O$$

$$O = PV$$ where $$V$$ is the <mark style="color:blue;">**value matrix**</mark> of size $$N × d$$.&#x20;

The <mark style="color:blue;">**attention output**</mark> $$O$$ is an $$N × d$$ <mark style="color:blue;">**matrix**</mark> that represents the <mark style="color:yellow;">**weighted sum of the value vectors**</mark>, with the weights determined by the <mark style="color:blue;">**attention probabilities**</mark> $$P$$.
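
Putting the three steps together, here is a minimal PyTorch sketch of this standard (naive) attention. To match the exposition above, the usual $$1/\sqrt{d}$$ scaling of the scores is omitted; the key point is that $$S$$ and $$P$$ are full $$N \times N$$ matrices.

```python
import torch

def standard_attention(Q, K, V):
    """Naive exact attention: materialises the full N x N matrices."""
    S = Q @ K.T                    # attention scores, shape (N, N)
    P = torch.softmax(S, dim=-1)   # attention probabilities, each row sums to 1
    O = P @ V                      # attention output, shape (N, d)
    return O

N, d = 8, 4
Q, K, V = (torch.randn(N, d) for _ in range(3))
O = standard_attention(Q, K, V)
assert O.shape == (N, d)
```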

### <mark style="color:green;">**Forward and Backward Pass**</mark>

In the <mark style="color:blue;">**forward pass**</mark>, the <mark style="color:yellow;">**input sequences are processed in parallel**</mark> across the batch dimension - multiple sequences are processed simultaneously - and the <mark style="color:blue;">**attention computation**</mark> is performed independently for each head.&#x20;

This parallelisation enables efficient utilisation of hardware resources and faster computation.

The backward pass, which is crucial for training the model, involves <mark style="color:yellow;">**computing the gradients of the loss with respect to the inputs**</mark>, <mark style="color:yellow;">**given the gradient of the output**</mark>.&#x20;

Let $$dO$$ be the <mark style="color:blue;">**gradient of the loss**</mark> with <mark style="color:yellow;">**respect to the attention output**</mark> $$O$$.&#x20;

By applying the <mark style="color:blue;">**chain rule**</mark> (backpropagation), we can compute the <mark style="color:blue;">**gradients**</mark> as follows:

1. Compute the gradient of $$V$$: $$dV = P^⊤ dO$$
2. Compute the gradient of $$P$$: $$dP = dO V^⊤$$
3. Compute the gradient of $$S$$ using the gradient of the softmax function: $$dS = dsoftmax(dP)$$
4. Compute the gradient of $$Q$$: $$dQ = dS K$$
5. Compute the gradient of $$K$$: $$dK = dS^⊤ Q$$

The $$dsoftmax$$ function denotes the <mark style="color:blue;">**gradient**</mark> of the <mark style="color:blue;">**softmax function**</mark> applied row-wise.

It can be shown that if $$p = softmax(s)$$ for some <mark style="color:blue;">**vectors**</mark> $$s$$ and $$p$$, then with the <mark style="color:blue;">**output gradient**</mark> $$dp$$, the <mark style="color:blue;">**input gradient**</mark> $$ds$$ is given by:&#x20;

$$ds = (diag(p) - pp^⊤) dp$$

where $$diag(p)$$ is a <mark style="color:blue;">**diagonal matrix**</mark> with the <mark style="color:blue;">**elements**</mark> of $$p$$ on its diagonal.&#x20;

This formula captures the relationship between the input and output <mark style="color:blue;">**gradients**</mark> of the <mark style="color:blue;">**softmax function**</mark>, considering the interdependence of the output probabilities.
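
These formulas can be checked numerically against PyTorch's autograd. Below is a minimal sketch that applies the hand-derived gradients, including the row-wise softmax backward $$ds = (diag(p) - pp^⊤)\,dp$$, and confirms they match the automatically computed ones.

```python
import torch

N, d = 6, 4
Q = torch.randn(N, d, requires_grad=True)
K = torch.randn(N, d, requires_grad=True)
V = torch.randn(N, d, requires_grad=True)

S = Q @ K.T
P = torch.softmax(S, dim=-1)
O = P @ V
dO = torch.randn(N, d)        # stand-in for the upstream loss gradient
O.backward(dO)

# Hand-derived gradients from the text
with torch.no_grad():
    dV = P.T @ dO
    dP = dO @ V.T
    # row-wise softmax backward: ds = (diag(p) - p p^T) dp for each row
    dS = P * (dP - (dP * P).sum(dim=-1, keepdim=True))
    dQ = dS @ K
    dK = dS.T @ Q

assert torch.allclose(dV, V.grad, atol=1e-5)
assert torch.allclose(dQ, Q.grad, atol=1e-5)
assert torch.allclose(dK, K.grad, atol=1e-5)
```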

#### <mark style="color:green;">**Memory Limitations**</mark>

In standard attention implementations, the <mark style="color:blue;">**attention scores matrix**</mark> $$S$$ and the <mark style="color:blue;">**attention probabilities matrix**</mark> $$P$$ are *<mark style="color:yellow;">**materialised**</mark>* and <mark style="color:yellow;">**stored in the high-bandwidth memory**</mark> (HBM) of the GPU.&#x20;

The memory requirement for storing these matrices is <mark style="color:yellow;">**proportional to**</mark> $$N^2$$, where $$N$$ is the <mark style="color:blue;">**sequence length**</mark>.

This <mark style="color:yellow;">**quadratic memory requirement**</mark> becomes a significant limitation when dealing with long sequences.&#x20;

In typical applications, the <mark style="color:blue;">**sequence length**</mark> $$N$$ can <mark style="color:yellow;">**range from 1,000 to 8,000 tokens**</mark>, while the <mark style="color:blue;">**head dimension**</mark> $$d$$ is usually <mark style="color:yellow;">**much smaller**</mark>, around 64-128.&#x20;

As a result, the memory consumption grows rapidly with increasing <mark style="color:blue;">**sequence length**</mark>, making it challenging to process long sequences efficiently.

To put this into perspective, let's consider an example where $$N$$ = 5,000 and $$d$$ = 64.&#x20;

The <mark style="color:blue;">**attention scores matrix**</mark> $$S$$ and the <mark style="color:blue;">**attention probabilities matrix**</mark> $$P$$ would each require storing 5,000 × 5,000 = 25,000,000 <mark style="color:blue;">**elements**</mark>.&#x20;

Assuming each element is a <mark style="color:blue;">**32-bit floating-point number**</mark>, the total memory requirement for both matrices would be approximately <mark style="color:yellow;">**200 MB**</mark>.&#x20;

This memory consumption can quickly become prohibitive, especially when dealing with larger sequence lengths or multiple attention heads.
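
The arithmetic behind that figure is straightforward:

```python
N, d = 5_000, 64
bytes_per_element = 4                     # FP32
per_matrix = N * N * bytes_per_element    # S or P: 100,000,000 bytes
total = 2 * per_matrix                    # S and P together
print(f"{per_matrix / 1e6:.0f} MB per matrix, {total / 1e6:.0f} MB total")
# 100 MB per matrix, 200 MB total -- per attention head, per sequence
```

Multiply by the number of heads and the batch size, and the quadratic term quickly dominates GPU memory.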

<figure><img src="https://1839612753-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FpV8SlQaC976K9PPsjApL%2Fuploads%2FMz7ibS1CnjYpeywKPAeT%2Fimage.png?alt=media&#x26;token=b81b7870-202e-4bc8-8397-4a4f3671ee13" alt=""><figcaption><p>As the image shows, SRAM is much faster than HBM but has far smaller capacity, since it is on-chip cache memory. Tensors are stored in HBM, while compute operations are performed on data staged into SRAM.</p></figcaption></figure>

### <mark style="color:green;">**Conclusion**</mark>

The <mark style="color:blue;">**standard attention mechanism**</mark>, while powerful, *<mark style="color:yellow;">**suffers from a quadratic memory requirement**</mark>* that limits its scalability to long sequences.&#x20;

This limitation arises from the need to materialise and store the <mark style="color:blue;">**attention scores**</mark> and probabilities matrices, which *<mark style="color:yellow;">**grow quadratically**</mark>* with the <mark style="color:blue;">**sequence length**</mark>.

To address this challenge, various alternative <mark style="color:blue;">**attention mechanisms**</mark> have been proposed, such as sparse attention, linearised attention, and hierarchical attention.&#x20;

These mechanisms have aimed to reduce the memory footprint and computational complexity of the attention mechanism while still capturing the essential dependencies in the input sequence.

To date, FlashAttention and FlashAttention-2 are among the most effective answers to this problem: they compute exact attention while eliminating the quadratic memory cost.

<details>

<summary><mark style="color:green;"><strong>The problem with the sequence length being larger than the head dimension</strong></mark></summary>

The significance of <mark style="color:blue;">**N**</mark> being much larger than <mark style="color:blue;">**d**</mark> lies in the memory complexity of the standard attention implementation.&#x20;

Since the <mark style="color:blue;">**attention scores matrix**</mark>**&#x20;**<mark style="color:yellow;">**S**</mark> and the <mark style="color:blue;">**attention probabilities matrix**</mark> <mark style="color:yellow;">**P**</mark> have dimensions <mark style="color:yellow;">**N×N**</mark>, the memory required to store these matrices grows quadratically with the sequence length <mark style="color:yellow;">**N**</mark>.&#x20;

For example, if <mark style="color:yellow;">**N**</mark> is 1,000 (1k), the <mark style="color:blue;">**matrices**</mark> would have 1,000,000 (1 million) elements, and if <mark style="color:yellow;">**N**</mark> is 8,000 (8k), the <mark style="color:blue;">**matrices**</mark> would have 64,000,000 (64 million) elements!

In contrast, the <mark style="color:blue;">**head dimension**</mark> <mark style="color:yellow;">**d**</mark> is typically <mark style="color:yellow;">**much smaller**</mark> than <mark style="color:yellow;">**N**</mark>.&#x20;

Even with a head dimension of 128, which is on the higher end of the typical range, it is still significantly smaller than the sequence lengths encountered in practice.

This difference in scale between <mark style="color:yellow;">**N**</mark> and <mark style="color:yellow;">**d**</mark> highlights the memory inefficiency of the standard attention implementation.&#x20;

As the <mark style="color:blue;">**sequence length**</mark> <mark style="color:yellow;">**N**</mark> grows, the *<mark style="color:yellow;">**memory required to store the attention matrices increases quadratically**</mark>*, while the <mark style="color:blue;">**head dimension**</mark> <mark style="color:yellow;">**d**</mark> remains relatively small.&#x20;

This leads to a bottleneck in terms of memory consumption, especially when dealing with long sequences, making it challenging to process such sequences efficiently using the standard attention implementation.

</details>

### <mark style="color:purple;">Flash Attention 2</mark>

<mark style="color:blue;">**FlashAttention-2**</mark> is an <mark style="color:yellow;">**improved version**</mark> of the <mark style="color:blue;">**FlashAttention**</mark> algorithm.  It aims to further optimise the <mark style="color:blue;">**attention mechanism**</mark> for longer <mark style="color:blue;">**sequences**</mark> by reducing memory usage and improving computational efficiency.&#x20;

FlashAttention-2 addresses the quadratic memory requirement issue of the standard <mark style="color:blue;">**attention mechanism**</mark> by employing several techniques to reduce memory reads/writes and improve parallelism.&#x20;

#### <mark style="color:green;">**Tiling and Block-wise Computation**</mark>

FlashAttention-2 employs a technique called <mark style="color:blue;">**tiling and block-wise computation**</mark> to reduce memory input/output (I/O) operations.&#x20;

This is achieved by dividing the <mark style="color:blue;">**input matrices**</mark> $$(Q, K, V)$$ and <mark style="color:blue;">**output matrix**</mark> $$(O)$$ into smaller blocks.

The purpose of this division is to ensure that these smaller blocks can fit into the GPU's on-chip <mark style="color:blue;">**SRAM (Static Random Access Memory)**</mark>, which is much faster than the main <mark style="color:blue;">**high-bandwidth memory (HBM)**</mark>.&#x20;

By operating on these smaller blocks, FlashAttention-2 minimises the need for frequent data transfers between the main memory and the GPU, thereby reducing the memory I/O bottleneck and improving overall performance.

<figure><img src="https://1839612753-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FpV8SlQaC976K9PPsjApL%2Fuploads%2FZ6keJo66TiN5lqQBrZqB%2Fimage.png?alt=media&#x26;token=e8b6942c-5499-4e40-b1ab-f013545af0ea" alt="" width="540"><figcaption><p>You can see the concept here: tiling means breaking the matrix into multiple blocks and performing the computation on those blocks in SRAM, instead of passing a complete N×N matrix from HBM to SRAM. SRAM is very limited in capacity, so the entire matrix cannot be loaded into it at once.</p></figcaption></figure>

#### <mark style="color:green;">**Online Softmax Computation**</mark>

* FlashAttention-2 uses the <mark style="color:yellow;">**online softmax technique**</mark> to enable block-wise computation of attention.
* Instead of computing the <mark style="color:blue;">**softmax function**</mark> over the entire <mark style="color:blue;">**attention scores matrix**</mark> $$(S)$$, FlashAttention-2 <mark style="color:yellow;">**computes the softmax function for each block independently**</mark>.
* It maintains a running maximum and a running sum of exponentials for each block, which allows for the correct computation of <mark style="color:blue;">**attention probabilities**</mark> without materialising the entire $$S$$ and $$P$$ matrices.
* The <mark style="color:yellow;">**online softmax technique**</mark> ensures that the output of each block is correctly scaled, resulting in the same output as the standard <mark style="color:blue;">**attention mechanism**</mark>, but with reduced memory requirements, as shown in the sketch below.
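
The following is a minimal PyTorch sketch of the online softmax trick combined with tiling: keys and values are processed one block at a time, a running maximum and running sum of exponentials are maintained per query row, and the full $$N \times N$$ matrices $$S$$ and $$P$$ are never materialised. Real FlashAttention-2 kernels also tile across the query dimension and run as fused CUDA kernels operating in SRAM; plain PyTorch stands in here for clarity.

```python
import torch

def blocked_attention(Q, K, V, block_size=128):
    """Exact attention computed block by block, without forming S or P."""
    N, d = Q.shape
    O = torch.zeros(N, d)
    row_max = torch.full((N, 1), float("-inf"))  # running max per query row
    row_sum = torch.zeros(N, 1)                  # running sum of exponentials

    for j in range(0, N, block_size):
        Kj, Vj = K[j:j + block_size], V[j:j + block_size]
        Sj = Q @ Kj.T                            # scores for this block only

        new_max = torch.maximum(row_max, Sj.max(dim=-1, keepdim=True).values)
        scale = torch.exp(row_max - new_max)     # rescale earlier partial results
        Pj = torch.exp(Sj - new_max)             # unnormalised probabilities

        row_sum = row_sum * scale + Pj.sum(dim=-1, keepdim=True)
        O = O * scale + Pj @ Vj
        row_max = new_max

    return O / row_sum                           # final correct normalisation

N, d = 512, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
reference = torch.softmax(Q @ K.T, dim=-1) @ V
assert torch.allclose(blocked_attention(Q, K, V), reference, atol=1e-5)
```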

#### <mark style="color:green;">**Recomputation and Gradient Checkpointing**</mark>

* In the <mark style="color:blue;">**backward pass**</mark>, FlashAttention-2 <mark style="color:yellow;">**avoids storing the large intermediate matrices**</mark> $$S$$ and $$P$$ by recomputing their values once the blocks of <mark style="color:blue;">**inputs**</mark> $$(Q, K, V)$$ are loaded into on-chip SRAM.
* By not having to save the $$S$$ and $$P$$ matrices, which have a size of $$N×N$$ (where $$N$$ is the <mark style="color:blue;">**sequence length**</mark>), FlashAttention-2 achieves a significant memory saving, *<mark style="color:yellow;">**reducing the memory requirement from quadratic to linear in sequence length**</mark>*.
* The recomputation strategy, also known as <mark style="color:blue;">**gradient checkpointing**</mark>, trades off some additional computation for reduced memory usage, enabling the processing of longer <mark style="color:blue;">**sequences**</mark> with limited memory resources; a sketch of this trade-off follows below.
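
In generic PyTorch, the same recompute-instead-of-store trade-off can be expressed with `torch.utils.checkpoint`; a rough sketch is below. Note that FlashAttention-2 performs the recomputation inside its fused kernel rather than through this API, so this is an analogy for the idea, not its actual mechanism.

```python
import torch
from torch.utils.checkpoint import checkpoint

def attention_block(Q, K, V):
    # The O(N^2) intermediates S and P would normally be kept for backward.
    return torch.softmax(Q @ K.T, dim=-1) @ V

N, d = 1024, 64
Q, K, V = (torch.randn(N, d, requires_grad=True) for _ in range(3))

# With checkpointing, the intermediates are freed after the forward pass and
# recomputed on demand during backward: extra FLOPs for less memory.
O = checkpoint(attention_block, Q, K, V, use_reentrant=False)
O.sum().backward()
```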

#### <mark style="color:green;">**Improved Parallelism**</mark>

* FlashAttention-2 <mark style="color:yellow;">**enhances parallelism**</mark> by processing different blocks of the attention computation on <mark style="color:yellow;">**different thread blocks**</mark> of the GPU.
* In the <mark style="color:blue;">**forward pass**</mark>**,** FlashAttention-2 parallelises the computation across the <mark style="color:blue;">**sequence length**</mark> dimension, *<mark style="color:yellow;">**in addition to the batch and head dimensions**</mark>*.  Each thread block handles a block of rows of the <mark style="color:blue;">**attention matrix**</mark>, allowing for efficient utilisation of GPU resources, especially when dealing with long <mark style="color:blue;">**sequences**</mark> and small <mark style="color:blue;">**batch sizes**</mark>.
* In the <mark style="color:blue;">**backward pass**</mark>, FlashAttention-2 *<mark style="color:yellow;">**parallelises the computation across the sequence length dimension as well**</mark>*, with each thread block handling a block of columns of the <mark style="color:blue;">**attention matrix**</mark>. This parallelisation scheme improves the utilisation of GPU resources and reduces the overall runtime.

<figure><img src="https://1839612753-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FpV8SlQaC976K9PPsjApL%2Fuploads%2FzBYWI5XAh7YLGQhaSpUz%2Fimage.png?alt=media&#x26;token=c8dd74a1-8eea-42a8-b1b7-f692d1b13990" alt="" width="548"><figcaption><p>From the paper: In the forward pass (left), we parallelise the workers (thread blocks) where each worker takes care of a block of rows of the attention matrix. In the backward pass (right), each worker takes care of a block of columns of the attention matrix</p></figcaption></figure>

#### <mark style="color:green;">**Optimised Work Partitioning**</mark>

* FlashAttention-2 optimises the work partitioning between different warps within each thread block to reduce communication and shared memory reads/writes.
* In the <mark style="color:blue;">**forward pass**</mark>, FlashAttention-2 splits the <mark style="color:blue;">**query matrix**</mark> $$(Q)$$ across different warps, while keeping the <mark style="color:blue;">**key**</mark> $$(K)$$ and <mark style="color:blue;">**value**</mark> $$(V)$$ matrices accessible by all warps.  This approach eliminates the need for communication between warps and reduces shared memory reads/writes compared to the "split-K" scheme used in the original FlashAttention.
* In the <mark style="color:blue;">**backward pass**</mark>, FlashAttention-2 also avoids the "split-K" scheme and partitions the work between warps to minimise shared memory reads/writes, leading to improved performance.

#### <mark style="color:purple;">**What is the "split-K" fix?**</mark>

* In the context of <mark style="color:blue;">**attention computation**</mark>, "split-K" refers to a scheme where the <mark style="color:blue;">**key matrix**</mark> $$(K)$$ is <mark style="color:yellow;">**split across different warps**</mark> (groups of threads) within a thread block.
* In the original FlashAttention implementation, the <mark style="color:blue;">**key matrix**</mark> $$(K)$$ and <mark style="color:blue;">**value matrix**</mark> $$(V)$$ are divided and assigned to different warps, while the <mark style="color:blue;">**query matrix**</mark> $$(Q)$$is accessible by all warps.
* After each warp multiplies the shared $$Q$$ against its assigned slice of $$K$$ to compute a slice of the <mark style="color:blue;">**attention scores**</mark> $$(QK^⊤)$$, the warps need to communicate and exchange intermediate results through shared memory.
* This communication and shared memory reads/writes introduce overhead and can slow down the computation.
* <mark style="color:yellow;">**FlashAttention-2 addresses this issue**</mark> by using a *<mark style="color:yellow;">**different work partitioning scheme**</mark>* in the forward pass, where the <mark style="color:blue;">**query matrix**</mark> $$Q$$ is split across warps instead of the <mark style="color:blue;">**key matrix**</mark> $$K$$.

<figure><img src="https://1839612753-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FpV8SlQaC976K9PPsjApL%2Fuploads%2FjyNUbItVm4CalzbQEiWI%2Fimage.png?alt=media&#x26;token=7b6d5991-0715-4a5d-9215-8f3ec089f9bc" alt=""><figcaption><p>Work partitioning between different warps in the forward pass</p></figcaption></figure>

### <mark style="color:purple;">Empirical Validation and Performance Gains</mark>

The FlashAttention-2 paper provides detailed benchmarks and performance comparisons that demonstrate the significant improvements achieved by the optimised <mark style="color:blue;">**attention mechanism**</mark>.&#x20;

#### <mark style="color:green;">Attention Computation Speed</mark>

FlashAttention-2 showcases remarkable speed improvements in attention computation compared to previous implementations.&#x20;

The paper presents comprehensive benchmarks on <mark style="color:blue;">**NVIDIA A100 GPUs**</mark>, comparing FlashAttention-2 with PyTorch's standard attention implementation, the original FlashAttention, and FlashAttention implemented in the Triton language.&#x20;

FlashAttention-2 achieves up to <mark style="color:yellow;">**3 times faster attention computation**</mark> compared to the original FlashAttention and up to *<mark style="color:yellow;">**10 times faster than PyTorch's standard attention implementation**</mark>*.&#x20;

For example, on <mark style="color:blue;">**sequences of length**</mark>**&#x20;**<mark style="color:yellow;">**2048**</mark> with a <mark style="color:blue;">**head dimension of**</mark>**&#x20;**<mark style="color:yellow;">**128**</mark>, FlashAttention-2 attains an impressive <mark style="color:yellow;">**187 TFLOPs/s**</mark> in the <mark style="color:blue;">**forward pass**</mark>, surpassing FlashAttention's <mark style="color:yellow;">**120 TFLOPs/s**</mark> and PyTorch's <mark style="color:yellow;">**60 TFLOPs/s**</mark>.&#x20;

These speed improvements are consistent across various <mark style="color:blue;">**sequence lengths**</mark> and <mark style="color:blue;">**head dimensions**</mark>, showcasing FlashAttention-2's ability to efficiently utilise GPU resources and minimise memory bottlenecks.&#x20;
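
To relate these throughput figures back to work done: exact attention costs roughly $$4N^2d$$ FLOPs per head in the forward pass (two matrix multiplications of $$2N^2d$$ FLOPs each, with softmax negligible by comparison). Using the throughput numbers quoted above, a back-of-envelope sketch:

```python
# Back-of-envelope: forward attention is ~2 matmuls of 2*N^2*d FLOPs each.
N, d = 2048, 128
flops = 2 * (2 * N * N * d)          # ~2.1 GFLOPs per head

for name, tflops in [("FlashAttention-2", 187),
                     ("FlashAttention", 120),
                     ("PyTorch standard", 60)]:
    seconds = flops / (tflops * 1e12)
    print(f"{name}: {seconds * 1e6:.1f} microseconds per head")
```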

#### <mark style="color:green;">Backward Pass Efficiency</mark>

The <mark style="color:blue;">**backward pass**</mark> in <mark style="color:blue;">**attention mechanisms**</mark> is known to be computationally challenging due to the dependencies between <mark style="color:blue;">**gradients**</mark> and the need to store intermediate values.&#x20;

However, FlashAttention-2 tackles this challenge head-on and achieves significant speedups in the <mark style="color:blue;">**backward pass**</mark> as well.   The paper presents detailed benchmarks comparing the <mark style="color:blue;">**backward pass**</mark> performance of FlashAttention-2 with other implementations.

On <mark style="color:blue;">**A100 GPUs**</mark>, FlashAttention-2 achieves up to a <mark style="color:yellow;">**2 times faster backward pass**</mark> than the original FlashAttention.

It also outperforms PyTorch's standard attention implementation by a significant margin.&#x20;

For instance, on <mark style="color:blue;">**sequences of length**</mark> <mark style="color:yellow;">**4096**</mark> with a <mark style="color:blue;">**head dimension**</mark> of <mark style="color:yellow;">**128**</mark>, FlashAttention-2 achieves <mark style="color:yellow;">**165 TFLOPs/s**</mark> in the <mark style="color:blue;">**backward pass**</mark>, while FlashAttention reaches <mark style="color:yellow;">**80 TFLOPs/s**</mark> and PyTorch's implementation obtains only <mark style="color:yellow;">**45 TFLOPs/s**</mark>.&#x20;

This showcases FlashAttention-2's ability to *<mark style="color:yellow;">**efficiently compute gradients and optimise memory usage**</mark>* in the <mark style="color:blue;">**backward pass**</mark>.

The implications of FlashAttention-2's performance improvements are far-reaching. It enables faster training and inference of Transformer models, opening up possibilities for handling longer <mark style="color:blue;">**sequences**</mark> and more complex tasks.&#x20;

### <mark style="color:purple;">The Impact of FlashAttention-2</mark>

FlashAttention-2 represents a significant advancement in the field of efficient Transformer models, particularly in addressing the limitations of the standard <mark style="color:blue;">**attention mechanism**</mark>.&#x20;

FlashAttention-2 achieves remarkable improvements in both speed and memory efficiency.&#x20;

The empirical results demonstrate that FlashAttention-2 can significantly outperform previous implementations, including the original FlashAttention and standard attention implementations in PyTorch.

The implications of FlashAttention-2 extend beyond just raw performance gains. It opens up new possibilities for training and deploying large Transformer models capable of handling longer <mark style="color:blue;">**sequences**</mark> and more complex tasks.&#x20;

As the demand for AI models that can understand and generate longer contexts continues to grow, FlashAttention-2 provides a promising solution to address the challenges associated with scaling attention mechanisms.
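
For practitioners, FlashAttention-2 is available through the author's `flash-attn` package (and, increasingly, through framework backends). A minimal usage sketch, assuming the package is installed and a supported CUDA GPU is available; the shapes shown follow the package's documented convention, and inputs must be fp16 or bf16:

```python
import torch
from flash_attn import flash_attn_func  # pip install flash-attn

# Package convention: (batch, seqlen, nheads, headdim) tensors on CUDA.
batch, seqlen, nheads, headdim = 2, 4096, 16, 128
q, k, v = (torch.randn(batch, seqlen, nheads, headdim,
                       dtype=torch.float16, device="cuda")
           for _ in range(3))

out = flash_attn_func(q, k, v, causal=True)   # exact attention, O(N) memory
```

Recent PyTorch releases can also dispatch `torch.nn.functional.scaled_dot_product_attention` to a FlashAttention backend, so many models benefit without calling the package directly.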

### <mark style="color:purple;">Paper References</mark>

Ainslie et al. (2023) and Shazeer (2019) discuss techniques for training and optimising Transformer models, such as multi-query attention and fast transformer decoding.

Several papers, including Beltagy et al. (2020), Katharopoulos et al. (2020), Kitaev et al. (2020), Roy et al. (2021), Wang et al. (2020), and Zaheer et al. (2020), propose alternative attention mechanisms to address the quadratic complexity of standard attention. These techniques include sparse attention, linearised attention, and hierarchical attention.

Chen et al. (2021) and Choromanski et al. (2020) introduce novel attention mechanisms, such as Scatterbrain and Performers, to improve the efficiency of Transformers.

Dao et al. (2022) propose the original FlashAttention algorithm, which serves as the foundation for FlashAttention-2.

Jia and Van Sandt (2021) and Jia et al. (2018) provide insights into the architecture and performance characteristics of NVIDIA GPUs, which are relevant for optimising attention computations.

Lefaudeux et al. (2022) introduce the xformers library, which includes implementations of various efficient attention mechanisms.

Milakov and Gimelshein (2018) and Rabe and Staats (2021) discuss techniques for optimising the softmax computation in attention mechanisms.

OpenAI (2023) presents the technical details of the GPT-4 model, which demonstrates the importance of scaling Transformers to longer sequences.

Shoeybi et al. (2019) introduce Megatron-LM, a large-scale language model trained using model parallelism, highlighting the need for efficient attention mechanisms in training billion-parameter models.

Tillet et al. (2019) propose Triton, an intermediate language and compiler for tiled neural network computations, which is used in some implementations of FlashAttention.

Vaswani et al. (2017) introduce the original Transformer architecture and the standard attention mechanism, which serves as the basis for subsequent research on efficient attention.

