Flash Attention
The authors of this June 2022 paper, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, propose an approach to computing exact attention that optimises memory usage and computational efficiency by leveraging the memory hierarchy of modern hardware.
The key idea behind FlashAttention is to exploit the fast on-chip memory (SRAM) of GPUs to store intermediate computations and minimise data movement between different levels of the memory hierarchy.
By carefully designing the attention computation to fit within the SRAM and optimising the data layout, FlashAttention achieves significant speedups and memory savings compared to traditional attention implementations.
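For contrast, here is a minimal sketch of a standard, unfused scaled dot-product attention in PyTorch; the function name and tensor shapes are illustrative rather than taken from the paper. The full sequence-length-by-sequence-length score matrix it materialises is exactly the intermediate that FlashAttention keeps out of slow GPU memory.

```python
import math
import torch

def standard_attention(Q, K, V):
    # Q, K, V: (batch, heads, seq_len, head_dim)
    d = Q.shape[-1]
    # The full (seq_len x seq_len) score matrix is materialised in GPU memory;
    # this is the intermediate that FlashAttention avoids writing out in full.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)
    probs = torch.softmax(scores, dim=-1)   # a second seq_len x seq_len tensor
    return probs @ V
```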
One of the strengths of this paper is the comprehensive analysis of the memory hierarchy and bandwidth considerations in the context of attention computation.
The authors provide a detailed breakdown of the memory requirements and data movement patterns for different attention variants, such as scaled dot-product attention and masked attention.
They demonstrate that by carefully managing data movement and maximising the utilisation of fast on-chip memory, FlashAttention can achieve up to 2.4 times speedup and 7.6 times memory reduction compared to the standard PyTorch implementation.
The evaluation of FlashAttention on the GPT-2 language model showcases its practical impact.
The authors show that FlashAttention can significantly reduce the memory footprint and computational time of attention layers, enabling the training and inference of larger models with limited GPU memory.
This is particularly relevant in the context of recent advancements in large-scale language models, where the attention mechanism is a major bottleneck in terms of memory and computational efficiency.
Attention mechanisms are crucial for these models as they help determine which parts of the input data the model should focus on. However, they can be resource-intensive, requiring significant memory and processing power.
The paper introduces a novel approach to managing the memory hierarchy during the attention computation, focusing on three main types of memory (a rough sizing example follows the list):
SRAM (Static Random-Access Memory): Fast and located on the GPU, but limited in size (19 TB/s bandwidth, 20 MB size).
HBM (High Bandwidth Memory): Also on the GPU, slower than SRAM but with a larger capacity (1.5 TB/s bandwidth, 40 GB size).
DRAM (Dynamic Random-Access Memory): Located on the CPU, with the largest capacity but the slowest bandwidth (12.8 GB/s, over 1 TB size).
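To make these sizes concrete, here is a back-of-the-envelope check of why the full attention score matrix cannot simply be kept in the roughly 20 MB of on-chip SRAM. The sequence length and data type are illustrative assumptions, not figures from the paper.

```python
# Rough sizing: can the full attention score matrix stay in on-chip SRAM?
# Illustrative values only (fp16 activations, A100-class figures from above).
seq_len = 4096            # tokens
bytes_per_elem = 2        # fp16
sram_mb = 20              # approximate on-chip SRAM capacity

score_matrix_mb = seq_len * seq_len * bytes_per_elem / 1e6   # one head, one batch element
print(f"score matrix: {score_matrix_mb:.0f} MB, SRAM: {sram_mb} MB")
# ~34 MB for a single head already exceeds SRAM, so FlashAttention streams
# Q, K and V through SRAM in tiles instead of materialising the full matrix.
```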
FlashAttention optimises the use of these memory types by carefully managing where and how data is stored and accessed during the computation process.
It introduces a method that minimises accesses to the slower levels of this hierarchy, in particular reads and writes of the large intermediate attention matrix to and from HBM, by doing as much of the computation as possible in SRAM.
This is achieved through tiling and kernel fusion: blocks of the inputs are loaded into SRAM once, processed completely, and only the final results are written back, which significantly speeds up the attention computation.
The technical details involve breaking the computation down into smaller blocks that can be processed within the fast but small memory space (SRAM) and then combining the partial results (a simplified sketch follows the list below). This process involves several steps, including:
Copying necessary data blocks to SRAM.
Computing attention blocks within SRAM.
Efficiently outputting results to HBM for further processing or storage.
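The sketch below illustrates that tiling idea in plain NumPy for a single attention head. It is a simplified, assumed reconstruction rather than the paper's exact algorithm (no masking, dropout, or backward pass): blocks of K and V are processed one at a time, and running softmax statistics are rescaled as each tile arrives, so the full score matrix never has to exist at once.

```python
import numpy as np

def tiled_attention(Q, K, V, block_size=128):
    """Blockwise exact attention with an online softmax (illustrative sketch)."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)

    out = np.zeros((n, d))
    m = np.full(n, -np.inf)   # running row-wise maximum of the scores
    l = np.zeros(n)           # running softmax normaliser

    for start in range(0, n, block_size):
        Kb = K[start:start + block_size]   # "copy a K/V tile into fast memory"
        Vb = V[start:start + block_size]

        S = (Q @ Kb.T) * scale                 # scores against this tile only
        m_new = np.maximum(m, S.max(axis=1))   # updated running maximum
        p = np.exp(S - m_new[:, None])         # tile's unnormalised probabilities
        correction = np.exp(m - m_new)         # rescales earlier partial results

        l = l * correction + p.sum(axis=1)
        out = out * correction[:, None] + p @ Vb
        m = m_new

    return out / l[:, None]   # final normalisation ("write the result to HBM")


# Sanity check against a straightforward softmax(Q K^T / sqrt(d)) V computation.
rng = np.random.default_rng(0)
Q = rng.standard_normal((512, 64))
K = rng.standard_normal((512, 64))
V = rng.standard_normal((512, 64))
S = Q @ K.T / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(tiled_attention(Q, K, V), ref)
```

The rescaling by exp(m - m_new) is what keeps the blockwise result identical to the full softmax, which is why this style of tiling yields exact rather than approximate attention.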
The FlashAttention method shows significant improvements in the time it takes to compute the attention mechanism in GPT-2. This not only means faster training times for these large models but also opens up possibilities for their deployment in environments where computational resources are limited.