Flash Attention 2

The seminal July 2023 paper


FlashAttention-2, the sequel to the original FlashAttention, demonstrates a technique that significantly boosts the efficiency of Transformer models.

With Transformers at the heart of generative AI models for text, code, audio, and video, their scalability is paramount.

However, the attention mechanism, a critical component of Transformers, has been a bottleneck due to its quadratic memory and runtime increase with sequence length.

Therefore, the main goal is to address the quadratic runtime and memory requirements of the attention layer, which hinder the scalability of Transformers to longer sequences.

FlashAttention-2, developed by the genius Tri Dao from the departments of Computer Science at Princeton University and Stanford University, tackled this issue head-on.

Interview with the genius inventor

A summary of the video transcript with Tri Dao

Tri Dao, the inventor of Flash Attention and Flash Attention 2, shares his insights on the development of these techniques, the current state of AI research, and his perspectives on the future of the field. Let's break down the key points from the discussion.

Insights and Ideas behind Flash Attention

  • Dao and his team focused on making attention more memory-efficient without approximations, as most practitioners preferred exact methods over approximate ones.

  • They realised that the majority of time in attention computation was spent on memory reading and writing, not floating-point operations (FLOPs).

  • By combining ideas from the systems side (kernel fusion, tiling) and the machine learning side (the online softmax trick), they developed Flash Attention, which reduces memory usage from quadratic to linear in sequence length.

  • Flash Attention provides a 2-4x speedup compared to a standard attention implementation, allowing for training with longer sequences at the same cost.

Development of Flash Attention 2

  • Dao started working on Flash Attention 2 as an excuse to play with NVIDIA's CUTLASS library, which provides primitives for efficient GPU operations.

  • He realised there were still inefficiencies in Flash Attention and made tweaks to reduce non-matmul FLOPs, as modern GPUs have specialised units (e.g., Tensor Cores) that make matmul much faster than non-matmul operations.

  • Flash Attention 2 also parallelises the computation across different thread blocks and partitions the work between warps to reduce shared memory access.

  • These optimisations result in a 2x speedup compared to Flash Attention, approaching the efficiency of highly optimised matrix multiplication routines.

Explanation of GPU Performance Characteristics

  • GPUs have a memory hierarchy consisting of high-bandwidth memory (HBM) and on-chip SRAM (shared memory).

  • HBM is larger but slower, while SRAM is much smaller but faster.

  • The goal is to design algorithms that take advantage of this asymmetric memory hierarchy, loading data from HBM to SRAM, performing computations, and then writing results back to HBM.

  • As sequence lengths grow, the memory required for attention matrices increases quadratically, while the number of attention heads remains relatively small, leading to a memory inefficiency that Flash Attention aims to address.

Role of Academia and Industry in AI Research

  • Academia and industry play complementary roles in AI research.

  • Industry has an advantage in scaling models and providing infrastructure, while academia focuses on developing new ideas, understanding and explaining models, and taking riskier bets on novel architectures.

  • Dao believes it's important for researchers to maintain a balance between exploring fundamentals and keeping up with the latest developments in the field.

Transformer Alternatives and Future Directions

  • Dao is interested in exploring transformer alternatives, such as recurrent neural networks (RNNs) and state space methods, which could be more efficient for long sequences and high-throughput generation.

  • He believes these alternatives could have advantages in specific use cases, such as processing high-resolution images, audio, and video.

  • However, he also acknowledges the "hardware lottery" and the challenges of developing new architectures that can compete with the heavily-optimised transformer on current hardware and software frameworks.

The Leap from FlashAttention to FlashAttention-2

FlashAttention made waves by reordering the computation of attention to leverage the GPU memory hierarchy more effectively, significantly speeding up the process and reducing memory usage from quadratic to linear in sequence length.

This yielded a 2-4 times speedup over optimised baselines and up to 10-20 times memory savings, with no approximation.

Despite its success, FlashAttention still fell short of reaching the efficiency of optimised General Matrix Multiply (GEMM) operations, achieving only 25-40% of the theoretical maximum FLOPs/s.

Refresher: GPU Performance Characteristics

Compute elements

GPUs consist of compute elements, such as floating-point arithmetic units, which perform the actual computations. These units are designed to handle parallel processing tasks efficiently.

Specialised matrix multiply units

Modern GPUs have specialised units, such as NVIDIA's Tensor Cores, that accelerate matrix multiplication operations in low-precision formats, such as FP16 (half-precision floating-point) or BF16 (bfloat16).

These units are optimised for fast and efficient matrix multiplication, which is a common operation in deep learning workloads.

Memory hierarchy

GPUs have a memory hierarchy that consists of two main components:

High Bandwidth Memory (HBM): HBM is a type of high-performance memory that offers high bandwidth and capacity. For example, the A100 GPU has 40-80GB of HBM with a bandwidth of 1.5-2.0TB/s. HBM is used to store large datasets and intermediate results during computation.

On-chip SRAM (Shared Memory): GPUs also have on-chip SRAM, known as shared memory, which is a fast but limited-capacity memory. In the A100 GPU, each of the 108 streaming multiprocessors (SMs) has 192KB of shared memory with an estimated bandwidth of around 19TB/s. Shared memory is used for fast communication and data sharing between threads within a thread block.

L2 cache: GPUs also have an L2 cache, which is a larger but slower cache compared to shared memory. However, the L2 cache is not directly controllable by the programmer and is managed by the GPU's cache hierarchy.
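
To make this asymmetry concrete, here is a back-of-the-envelope estimate using the bandwidth figures quoted above. The sequence length and FP16 element size are illustrative assumptions, and real kernels overlap computation with memory traffic, so this only gives a rough sense of why keeping the attention matrices out of HBM matters.

```python
# Rough data-movement estimate using the A100 bandwidth figures quoted above.
# The sequence length and FP16 element size are illustrative assumptions.
N = 4096                                  # sequence length
bytes_per_elem = 2                        # FP16
matrix_bytes = N * N * bytes_per_elem     # one N x N attention matrix (~33.5 MB)

hbm_bandwidth = 2.0e12                    # ~2.0 TB/s (HBM)
sram_bandwidth = 19e12                    # ~19 TB/s (on-chip SRAM, as quoted above)

# Writing S out to HBM and reading it back in for the softmax costs roughly:
round_trip = 2 * matrix_bytes
print(f"HBM round trip: {round_trip / hbm_bandwidth * 1e6:.1f} microseconds")
print(f"At SRAM speed:  {round_trip / sram_bandwidth * 1e6:.1f} microseconds")
```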

Execution Model

Threads: GPUs employ a massive number of threads to execute operations in parallel. These threads are lightweight and can be quickly switched to hide memory latency and maximise utilisation of the GPU's compute resources.

Thread blocks: Threads are organised into thread blocks, which are scheduled to run on the GPU's streaming multiprocessors (SMs). Each thread block contains a set of threads that can cooperate and share resources, such as shared memory.

Warps: Within each thread block, threads are further grouped into warps, typically consisting of 32 threads. Threads within a warp execute in lockstep, meaning they perform the same instruction on different data elements simultaneously (SIMD - Single Instruction, Multiple Data). Warps are the basic unit of execution on the GPU.

Communication and synchronization

Shuffle instructions: Threads within a warp can communicate using fast shuffle instructions, which allow them to exchange data directly without going through shared memory. This is useful for efficient data sharing and reduction operations within a warp.

Shared memory: Warps within a thread block can communicate by reading from and writing to shared memory. Shared memory acts as a fast, software-managed cache that enables efficient data sharing and coordination between warps.

Kernel execution: A GPU kernel is a function that is executed in parallel by multiple threads on the GPU. Each kernel loads inputs from HBM into registers and shared memory, performs computations, and then writes the outputs back to HBM. The kernel is launched with a specified grid and block size, which determines the number of thread blocks and threads per block.

Understanding these GPU performance characteristics and the execution model is key to optimising GPU code and achieving high performance.

By leveraging the specialised matrix multiply units, using the memory hierarchy effectively, and organising threads and thread blocks appropriately, developers can maximise the utilisation of GPU resources and minimise memory latency, ultimately leading to faster execution of parallel workloads.

Some definitions

Sequence Length: The number of tokens in an input sequence. For example, in natural language processing tasks, the sequence length is the number of words or subwords in a given text input.

Batch Dimension: Data is often processed in batches to improve computational efficiency and convergence properties. The batch dimension refers to the number of samples or sequences that are processed together in a single forward or backward pass of the model.

Instead of processing one sequence at a time, multiple sequences are stacked together to form a batch. This allows for parallel computation and better utilisation of GPUs.

GEMM: GEMM stands for General Matrix Multiplication, which is a fundamental operation in linear algebra and is used in attention mechanisms.

GEMM involves multiplying two matrices (often denoted as A and B) and adding the result to a third matrix (C).

The standard formula for GEMM is:

C = αAB + βC

where α and β are scalar coefficients (usually set to 1 and 0, respectively, for standard matrix multiplication).
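
As a concrete illustration, PyTorch exposes this operation directly through torch.addmm, which computes βC + αAB; the matrix sizes below are illustrative only.

```python
import torch

# GEMM: C = alpha * A @ B + beta * C. With alpha = 1 and beta = 0 this
# reduces to a plain matrix multiplication, as noted above.
N, d = 1024, 64                      # illustrative sequence length and head dimension
A = torch.randn(N, d)
B = torch.randn(d, N)
C = torch.zeros(N, N)

out = torch.addmm(C, A, B, beta=0.0, alpha=1.0)
assert torch.allclose(out, A @ B)
```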

GEMM and the Attention Mechanism

In the context of attention mechanisms, GEMM is used to compute the attention scores (QK^⊤) and the attention output (PV), where:

  • Q is the query matrix

  • K is the key matrix

  • V is the value matrix

  • P is the attention probability matrix (obtained by applying softmax to the attention scores)

The dimensions of these matrices are:

  • Q, K, V ∈ ℝ^(N×d), where N is the sequence length and d is the head dimension

  • P ∈ ℝ^(N×N)

  • Output O ∈ ℝ^(N×d)

In this notation:

  1. The symbol "∈" is the mathematical notation for "belongs to" or "is an element of." In this context, it is used to indicate that a matrix belongs to a specific set of dimensions.

  2. ℝ represents the set of real numbers

  3. (N, d) and (N, N) represent the dimensions of the matrices, with N being the number of rows and d or N being the number of columns, depending on the matrix

The shortfalls of the standard attention mechanism

The standard attention mechanism is a fundamental component of Transformer models.

However, as highlighted, it suffers from a significant limitation: its inability to scale efficiently as the sequence length (the number of tokens or words in the input) increases.

Attention Mechanism - how does it work?

Given an input sequence of length N and a head dimension d, the goal of the attention mechanism is to compute the attention output O, which is an N × d matrix.

The attention mechanism then operates as follows:

Compute the attention scores S

S = QK^⊤, where Q and K are the query and key matrices, respectively, both of size N × d.

The attention score matrix S is an N × N matrix that represents the similarity or relevance between each pair of tokens in the input sequence.

Apply the softmax function to the attention scores S row-wise

P = softmax(S)

The softmax function normalises the attention scores, converting them into the attention probabilities P, which is also an N × N matrix.

Each row of P sums to 1, representing the distribution of attention weights for a specific token.

Compute the attention output O

O = PV, where V is the value matrix of size N × d.

The attention output O is an N × d matrix that represents the weighted sum of the value vectors, with the weights determined by the attention probabilities P.
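
Put together, a minimal single-head PyTorch sketch of this standard (unfused) attention looks as follows. It matches the formulas above, so the usual 1/√d scaling of the scores is omitted, and it materialises both N × N matrices S and P, which is exactly the memory problem discussed below.

```python
import torch

def standard_attention(Q, K, V):
    """Naive attention that materialises the full N x N matrices S and P."""
    S = Q @ K.T                        # attention scores, shape (N, N)
    P = torch.softmax(S, dim=-1)       # attention probabilities, rows sum to 1
    O = P @ V                          # attention output, shape (N, d)
    return O

N, d = 2048, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
print(standard_attention(Q, K, V).shape)   # torch.Size([2048, 64])
```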

Forward and Backward Pass

In the forward pass, the input sequences are processed in parallel across the batch dimension - multiple sequences are processed simultaneously - and the attention computation is performed independently for each head.

This parallelisation enables efficient utilisation of hardware resources and faster computation.

The backward pass, which is crucial for training the model, involves computing the gradients of the loss with respect to the inputs Q, K, and V, given the gradient of the loss with respect to the output.

Let dO be the gradient of the loss with respect to the attention output O.

By applying the chain rule (backpropagation), we can compute the gradients as follows:

  1. Compute the gradient of V: dV = P^⊤ dO

  2. Compute the gradient of P: dP = dO V^⊤

  3. Compute the gradient of S using the gradient of the softmax function: dS = dsoftmax(dP)

  4. Compute the gradient of Q: dQ = dS K

  5. Compute the gradient of K: dK = dS^⊤ Q

The dsoftmax function denotes the gradient of the softmax function applied row-wise.

It can be shown that if p = softmax(s) for some vectors s and p, then with the output gradient dp, the input gradient ds is given by:

ds = (diag(p) − pp^⊤) dp

where diag(p) is a diagonal matrix with the elements of p on its diagonal.

This formula captures the relationship between the input and output gradients of the softmax function, considering the interdependence of the output probabilities.
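
As a quick sanity check, this formula can be compared against PyTorch autograd for a single row of scores. The 8-element vector and the random upstream gradient below are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
s = torch.randn(8, requires_grad=True)     # one row of attention scores
p = torch.softmax(s, dim=0)

dp = torch.randn(8)                        # arbitrary upstream gradient
p.backward(dp)                             # autograd fills in s.grad

# ds = (diag(p) - p p^T) dp, as derived above
p_ = p.detach()
ds_manual = (torch.diag(p_) - torch.outer(p_, p_)) @ dp
assert torch.allclose(s.grad, ds_manual, atol=1e-6)
```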

Memory Limitations

In standard attention implementations, the attention scores matrix S and the attention probabilities matrix P are materialised and stored in the high-bandwidth memory (HBM) of the GPU.

The memory requirement for storing these matrices is proportional to N², where N is the sequence length.

This quadratic memory requirement becomes a significant limitation when dealing with long sequences.

In typical applications, the sequence length N can range from 1,000 to 8,000 tokens, while the head dimension d is usually much smaller, around 64-128.

As a result, the memory consumption grows rapidly with increasing sequence length, making it challenging to process long sequences efficiently.

To put this into perspective, let's consider an example where N = 5,000 and d = 64.

The attention scores matrix S and the attention probabilities matrix P would each require storing 5,000 × 5,000 = 25,000,000 elements.

Assuming each element is a 32-bit floating-point number, the total memory requirement for both matrices would be approximately 200 MB.

This memory consumption can quickly become prohibitive, especially when dealing with larger sequence lengths or multiple attention heads.
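
The arithmetic behind these figures, for a single attention head and a single sequence in the batch:

```python
# Memory to materialise the N x N matrices S and P in 32-bit floating point.
N = 5_000
bytes_per_element = 4                        # FP32
per_matrix = N * N * bytes_per_element       # 100,000,000 bytes ≈ 100 MB
print(f"S or P alone: {per_matrix / 1e6:.0f} MB")
print(f"S and P together: {2 * per_matrix / 1e6:.0f} MB")   # ≈ 200 MB, as above
```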

Conclusion

The standard attention mechanism, while powerful, suffers from a quadratic memory requirement that limits its scalability to long sequences.

This limitation arises from the need to materialise and store the attention scores and probabilities matrices, which grow quadratically with the sequence length.

To address this challenge, various alternative attention mechanisms have been proposed, such as sparse attention, linearised attention, and hierarchical attention.

These mechanisms have aimed to reduce the memory footprint and computational complexity of the attention mechanism while still capturing the essential dependencies in the input sequence.

To date, Flash Attention and Flash Attention 2 are the most effective solutions to this quadratic memory problem.

The problem with the sequence length being larger than the head dimension

The significance of N being much larger than d lies in the memory complexity of the standard attention implementation.

Since the attention scores matrix S and the attention probabilities matrix P have dimensions N×N, the memory required to store these matrices grows quadratically with the sequence length N.

For example, if N is 1,000 (1k), the matrices would have 1,000,000 (1 million) elements, and if N is 8,000 (8k), the matrices would have 64,000,000 (64 million) elements!

In contrast, the head dimension d is typically much smaller than N.

Even with a head dimension of 128, which is on the higher end of the typical range, it is still significantly smaller than the sequence lengths encountered in practice.

This difference in scale between N and d highlights the memory inefficiency of the standard attention implementation.

As the sequence length N grows, the memory required to store the attention matrices increases quadratically, while the head dimension d remains relatively small.

This leads to a bottleneck in terms of memory consumption, especially when dealing with long sequences, making it challenging to process such sequences efficiently using the standard attention implementation.

Flash Attention 2

FlashAttention-2 is an improved version of the FlashAttention algorithm. It aims to further optimise the attention mechanism for longer sequences by reducing memory usage and improving computational efficiency.

FlashAttention-2 addresses the quadratic memory requirement issue of the standard attention mechanism by employing several techniques to reduce memory reads/writes and improve parallelism.

Tiling and Block-wise Computation

FlashAttention-2 employs a technique called tiling and block-wise computation to reduce memory input/output (I/O) operations.

This is achieved by dividing the input matrices (Q, K, V) and the output matrix (O) into smaller blocks.

The purpose of this division is to ensure that these smaller blocks can fit into the GPU's on-chip SRAM (Static Random Access Memory), which is much faster than the main high-bandwidth memory (HBM).

By operating on these smaller blocks, FlashAttention-2 minimises the need for frequent data transfers between the main memory and the GPU, thereby reducing the memory I/O bottleneck and improving overall performance.

Online Softmax Computation

  • FlashAttention-2 uses the online softmax technique to enable block-wise computation of attention.

  • Instead of computing the softmax function over the entire attention scores matrix (S), FlashAttention-2 computes the softmax function for each block independently.

  • It maintains a running maximum and a running sum of exponentials for each block, which allows for the correct computation of attention probabilities without materialising the entire S and P matrices (see the sketch after this list).

  • The online softmax technique ensures that the output of each block is correctly scaled, resulting in the same output as the standard attention mechanism, but with reduced memory requirements.
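
Below is a minimal single-head sketch of the tiling plus online-softmax idea, written at the PyTorch level rather than as a fused CUDA kernel. It is illustrative only: it processes K and V in blocks along the sequence dimension, whereas FlashAttention-2 also tiles Q and performs all of this inside on-chip SRAM, and the usual 1/√d scaling is again omitted to match the formulas above.

```python
import torch

def blockwise_attention(Q, K, V, block_size=256):
    """Online-softmax attention: K and V are consumed in blocks, so the full
    N x N score matrix is never materialised."""
    N, d = Q.shape
    m = torch.full((N,), float("-inf"))   # running row-wise maximum of the scores
    l = torch.zeros(N)                    # running row-wise sum of exponentials
    acc = torch.zeros(N, d)               # running (unnormalised) output

    for start in range(0, N, block_size):
        Kb = K[start:start + block_size]  # current block of keys,   shape (B, d)
        Vb = V[start:start + block_size]  # current block of values, shape (B, d)
        Sb = Q @ Kb.T                     # scores against this block, shape (N, B)

        m_new = torch.maximum(m, Sb.max(dim=1).values)
        rescale = torch.exp(m - m_new)            # correct the old statistics
        Pb = torch.exp(Sb - m_new[:, None])       # unnormalised block probabilities

        l = l * rescale + Pb.sum(dim=1)
        acc = acc * rescale[:, None] + Pb @ Vb
        m = m_new

    return acc / l[:, None]               # final normalisation gives softmax(QK^T) V

N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
reference = torch.softmax(Q @ K.T, dim=-1) @ V
assert torch.allclose(blockwise_attention(Q, K, V), reference, atol=1e-4)
```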

Recomputation and Gradient Checkpointing

  • In the backward pass, FlashAttention-2 avoids storing the large intermediate matrices S and P by recomputing their values once the blocks of inputs (Q, K, V) are loaded into on-chip SRAM.

  • By not having to save the S and P matrices, which have a size of N × N (where N is the sequence length), FlashAttention-2 achieves a significant memory saving, reducing the memory requirement from quadratic to linear in sequence length.

  • The recomputation strategy, also known as gradient checkpointing, trades off some additional computation for reduced memory usage, enabling the processing of longer sequences with limited memory resources (a framework-level analogue is sketched after this list).
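
For intuition, PyTorch exposes the same trade-off at the framework level through torch.utils.checkpoint; FlashAttention-2 performs the equivalent recomputation inside its fused kernel rather than through this API. A minimal sketch:

```python
import torch
from torch.utils.checkpoint import checkpoint

def attention_block(Q, K, V):
    # Standard attention; S and P are intermediate activations of this function.
    P = torch.softmax(Q @ K.T, dim=-1)
    return P @ V

N, d = 2048, 64
Q, K, V = (torch.randn(N, d, requires_grad=True) for _ in range(3))

# With checkpointing, the intermediates are not kept for the backward pass;
# they are recomputed from Q, K and V when the gradients are needed.
O = checkpoint(attention_block, Q, K, V, use_reentrant=False)
O.sum().backward()
print(Q.grad.shape)   # torch.Size([2048, 64])
```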

Improved Parallelism

  • FlashAttention-2 enhances parallelism by processing different blocks of the attention computation on different thread blocks of the GPU.

  • In the forward pass, FlashAttention-2 parallelises the computation across the sequence length dimension, in addition to the batch and head dimensions. Each thread block handles a block of rows of the attention matrix, allowing for efficient utilisation of GPU resources, especially when dealing with long sequences and small batch sizes.

  • In the backward pass, FlashAttention-2 parallelises the computation across the sequence length dimension as well, with each thread block handling a block of columns of the attention matrix. This parallelisation scheme improves the utilisation of GPU resources and reduces the overall runtime.

Optimised Work Partitioning

  • FlashAttention-2 optimises the work partitioning between different warps within each thread block to reduce communication and shared memory reads/writes.

  • In the forward pass, FlashAttention-2 splits the query matrix (Q) across different warps, while keeping the key (K) and value (V) matrices accessible by all warps. This approach eliminates the need for communication between warps and reduces shared memory reads/writes compared to the "split-K" scheme used in the original FlashAttention.

  • In the backward pass, FlashAttention-2 also avoids the "split-K" scheme and partitions the work between warps to minimise shared memory reads/writes, leading to improved performance.

What is the "Split K" fix

  • In the context of attention computation, "split-K" refers to a scheme where the key matrix (K) is split across different warps (groups of threads) within a thread block.

  • In the original FlashAttention implementation, the key matrix (K) and value matrix (V) are divided and assigned to different warps, while the query matrix (Q) is accessible by all warps.

  • After each warp multiplies Q by its assigned portion of K to compute a slice of the attention scores (QK^⊤), the warps need to communicate and exchange intermediate results through shared memory.

  • This communication and shared memory reads/writes introduce overhead and can slow down the computation.

  • FlashAttention-2 addresses this issue by using a different work partitioning scheme in the forward pass, where the query matrix Q is split across warps instead of the key matrix K.

Empirical Validation and Performance Gains

The FlashAttention-2 paper provides detailed benchmarks and performance comparisons that demonstrate the significant improvements achieved by the optimised attention mechanism.

Attention Computation Speed

FlashAttention-2 showcases remarkable speed improvements in attention computation compared to previous implementations.

The paper presents comprehensive benchmarks on NVIDIA A100 GPUs, comparing FlashAttention-2 with PyTorch's standard attention implementation, the original FlashAttention, and FlashAttention implemented in the Triton language.

FlashAttention-2 achieves up to 3 times faster attention computation compared to the original FlashAttention and up to 10 times faster than PyTorch's standard attention implementation.

For example, on sequences of length 2048 with a head dimension of 128, FlashAttention-2 attains an impressive 187 TFLOPs/s in the forward pass, surpassing FlashAttention's 120 TFLOPs/s and PyTorch's 60 TFLOPs/s.

These speed improvements are consistent across various sequence lengths and head dimensions, showcasing FlashAttention-2's ability to efficiently utilise GPU resources and minimise memory bottlenecks.

Backward Pass Efficiency

The backward pass in attention mechanisms is known to be computationally challenging due to the dependencies between gradients and the need to store intermediate values.

However, FlashAttention-2 tackles this challenge head-on and achieves significant speedups in the backward pass as well. The paper presents detailed benchmarks comparing the backward pass performance of FlashAttention-2 with other implementations.

On A100 GPUs, FlashAttention-2's backward pass is up to 2 times faster than that of the original FlashAttention.

It also outperforms PyTorch's standard attention implementation by a significant margin.

For instance, on sequences of length 4096 with a head dimension of 128, FlashAttention-2 achieves 165 TFLOPs/s in the backward pass, while FlashAttention reaches 80 TFLOPs/s and PyTorch's implementation obtains only 45 TFLOPs/s.

This showcases FlashAttention-2's ability to efficiently compute gradients and optimise memory usage in the backward pass.

The implications of FlashAttention-2's performance improvements are far-reaching. It enables faster training and inference of Transformer models, opening up possibilities for handling longer sequences and more complex tasks.
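
In practice, these kernels are increasingly available without writing any CUDA. Recent PyTorch releases route torch.nn.functional.scaled_dot_product_attention to fused FlashAttention-style kernels when the GPU, dtype, and shapes allow it (exactly which kernel is selected depends on the PyTorch version and hardware), and the standalone flash-attn package exposes the FlashAttention-2 kernels directly. A sketch of the PyTorch route, assuming a CUDA device is available:

```python
import torch
import torch.nn.functional as F

# Batched multi-head layout: (batch, heads, sequence length, head dimension).
q, k, v = (torch.randn(2, 16, 4096, 128, device="cuda", dtype=torch.float16)
           for _ in range(3))

# On supported GPUs/dtypes/shapes this dispatches to a fused, memory-efficient
# attention kernel instead of materialising the N x N score matrix.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)   # torch.Size([2, 16, 4096, 128])
```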

The Impact of FlashAttention-2

FlashAttention-2 represented a significant advancement in the field of efficient Transformer models, particularly in addressing the limitations of the standard attention mechanism.

FlashAttention-2 achieves remarkable improvements in both speed and memory efficiency.

The empirical results demonstrate that FlashAttention-2 can significantly outperform previous implementations, including the original FlashAttention and standard attention implementations in PyTorch.

The implications of FlashAttention-2 extend beyond just raw performance gains. It opens up new possibilities for training and deploying large Transformer models capable of handling longer sequences and more complex tasks.

As the demand for AI models that can understand and generate longer contexts continues to grow, FlashAttention-2 provides a promising solution to address the challenges associated with scaling attention mechanisms.

Paper References

Ainslie et al. (2023) and Shazeer (2019) discuss techniques for training and optimising Transformer models, such as multi-query attention and fast transformer decoding.

Several papers, including Beltagy et al. (2020), Katharopoulos et al. (2020), Kitaev et al. (2020), Roy et al. (2021), Wang et al. (2020), and Zaheer et al. (2020), propose alternative attention mechanisms to address the quadratic complexity of standard attention. These techniques include sparse attention, linearised attention, and hierarchical attention.

Chen et al. (2021) and Choromanski et al. (2020) introduce novel attention mechanisms, such as Scatterbrain and Performers, to improve the efficiency of Transformers.

Dao et al. (2022) propose the original FlashAttention algorithm, which serves as the foundation for FlashAttention-2.

Jia and Van Sandt (2021) and Jia et al. (2018) provide insights into the architecture and performance characteristics of NVIDIA GPUs, which are relevant for optimising attention computations.

Lefaudeux et al. (2022) introduce the xformers library, which includes implementations of various efficient attention mechanisms.

Milakov and Gimelshein (2018) and Rabe and Staats (2021) discuss techniques for optimising the softmax computation in attention mechanisms.

OpenAI (2023) presents the technical details of the GPT-4 model, which demonstrates the importance of scaling Transformers to longer sequences.

Shoeybi et al. (2019) introduce Megatron-LM, a large-scale language model trained using model parallelism, highlighting the need for efficient attention mechanisms in training billion-parameter models.

Tillet et al. (2019) propose Triton, an intermediate language and compiler for tiled neural network computations, which is used in some implementations of FlashAttention.

Vaswani et al. (2017) introduce the original Transformer architecture and the standard attention mechanism, which serves as the basis for subsequent research on efficient attention.

Figure: SRAM is much faster than HBM, but it has far less capacity because it is a small, fast cache memory. HBM is used to store the tensors, while the compute operations on those tensors are performed out of SRAM.

Figure: Tiling means breaking the matrix into multiple blocks and performing the computation block by block in SRAM, instead of passing a complete N×N matrix from HBM to SRAM; SRAM's capacity is far too limited to hold the entire matrix.

Figure (from the paper): In the forward pass (left), the workers (thread blocks) are parallelised so that each worker takes care of a block of rows of the attention matrix. In the backward pass (right), each worker takes care of a block of columns of the attention matrix.

Figure: Work partitioning between different warps in the forward pass.
FlashAttention-2: Faster Attention with Better Parallelism and... (arXiv.org)