FiDO: Fusion-in-Decoder optimised for stronger performance and faster inference
Google Research, December 2022
FiD (Fusion-in-Decoder) is a retrieval-augmented language model that encodes retrieved passages independently and fuses them in the decoder, improving performance on knowledge-intensive natural language processing tasks.
While FiD has set new benchmarks, the original architecture, borrowed with minimal changes from the T5 model, isn't fully optimised for retrieval-augmented tasks.
Encoder-Decoder Resource Allocation: In FiD, resource allocation is heavily skewed: the encoder consumes the vast majority of Floating Point Operations (FLOPs) and training time. During inference, however, the decoder, not the encoder, becomes the primary computational bottleneck, largely due to memory bandwidth constraints.
Inference Time Bottleneck: Despite the encoder's resource intensity, the decoder is the critical point of expense at inference time: it must assimilate and reason over information from many retrieved passages, and its multi-head cross-attention over the long concatenated encoder output dominates the cost.
FLOPs Calculation: In the context of Transformer models, FLOPs are calculated based on operations in various components, primarily MLP (multi-layer perceptron) layers, attention projection layers, and attention operations. For simplicity, this analysis focuses on multiplication operations, a common approach to estimate computational complexity.
Encoder and Decoder FLOPs: The FLOP counts are expressed in terms of the model dimension d, the total number of tokens across all retrieved passages n_s, the number of tokens in a single passage n_p, the number of target (generated) tokens n_t, and the number of layers L.
With the MLP dimension set to 4d, each token costs roughly 8d² multiplications in the MLP block and 4d² in the QKVO (query, key, value and output) attention projections, so these two terms dominate the encoder's FLOPs. The decoder adds cross-attention layers, whose keys and values span all n_s retrieved tokens, which is what allows it to integrate the retrieved information.
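To make this accounting concrete, here is a rough back-of-envelope estimator: a sketch under the assumptions above (only multiplications are counted, the MLP dimension is 4d, encoder self-attention stays within each passage). The exact constants in the paper's analysis may differ, and the function name and example sizes below are ours.

```python
def fid_flops(d, n_s, n_p, n_t, L):
    """Approximate multiplication counts for a FiD-style encoder-decoder.

    d   : model dimension
    n_s : total source tokens across all retrieved passages
    n_p : tokens in a single passage (encoder self-attention stays within a passage)
    n_t : target (generated) tokens
    L   : number of encoder layers = number of decoder layers
    Assumes MLP dimension = 4d and counts multiplications only.
    """
    # Encoder per layer: MLP (d -> 4d -> d), QKVO projections, per-passage self-attention.
    enc = L * (8 * n_s * d * d            # MLP
               + 4 * n_s * d * d          # QKVO projections
               + 2 * n_s * n_p * d)       # attention scores + weighted values
    # Decoder per layer: MLP, self-attention over generated tokens, and cross-attention
    # whose keys/values cover all n_s source tokens.
    dec = L * (8 * n_t * d * d            # MLP
               + 4 * n_t * d * d          # self-attention QKVO
               + 2 * n_t * n_t * d        # self-attention
               + 2 * n_t * d * d          # cross-attention Q and O projections
               + 2 * n_s * d * d          # cross-attention K and V projections (computed once)
               + 2 * n_t * n_s * d)       # cross-attention scores + weighted values
    return enc, dec


# Illustrative sizes: 100 retrieved passages of 256 tokens each, 32 output tokens.
enc, dec = fid_flops(d=1024, n_s=100 * 256, n_p=256, n_t=32, L=24)
print(f"encoder ~{enc:.2e} multiplications, decoder ~{dec:.2e} multiplications")
```

In typical retrieval settings n_s is tens of thousands of tokens while n_t is only a few dozen, which is why the encoder dominates the FLOP count even though, as discussed next, the decoder dominates inference time.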
Memory Bandwidth Constraints: Although the encoder consumes more FLOPs, the decoder is the primary bottleneck during inference. During incremental (token-by-token) decoding, the cross-attention keys and values for all n_s retrieved tokens must be loaded from memory at every step, so the decoder is memory-bandwidth bound rather than compute bound.
Proposed Architectural Changes: To optimise FiD, the paper suggests reducing the number of cross-attention layers in the decoder and replacing multi-head attention with multi-query attention. These changes aim to reduce the decoder's memory bandwidth requirements, thereby speeding up inference.
Operational Intensity: Operational intensity, the ratio of compute operations to memory traffic, determines how much of an accelerator's peak throughput can actually be achieved; GPUs and TPUs need high operational intensity to run efficiently. The proposed optimisations aim to raise operational intensity, particularly in the decoder.
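The snippet below illustrates the point with a simple ratio (an illustrative sketch, not the paper's accounting): in one incremental decoding step, multi-head cross-attention must stream the full key/value cache for all n_s source tokens from memory while performing only about 2·n_s·d multiplications, whereas multi-query attention shares a single key/value head across all heads and so raises the operations-per-byte ratio by roughly the number of heads. The function name and example sizes are ours.

```python
def cross_attention_intensity(d, n_s, n_heads, bytes_per_elem=2, multi_query=False):
    """Multiplications per byte of key/value cache streamed during one
    incremental decoding step of cross-attention (Q/O projections ignored).

    With multi-head attention the full cache of 2 * n_s * d elements is read
    every step; with multi-query attention the heads share one key/value head
    of width d / n_heads, cutting that traffic by a factor of n_heads while
    the multiplication count stays the same.
    """
    kv_width = d // n_heads if multi_query else d       # per-token width of K (and of V)
    ops = 2 * n_s * d                                    # scores + weighted values for one query token
    bytes_loaded = 2 * n_s * kv_width * bytes_per_elem   # keys + values read from memory
    return ops / bytes_loaded


# Illustrative: 100 passages x 256 tokens, d = 1024, 16 heads, bf16 cache.
print("multi-head :", cross_attention_intensity(1024, 25_600, 16))                    # ~0.5 ops/byte
print("multi-query:", cross_attention_intensity(1024, 25_600, 16, multi_query=True))  # ~8 ops/byte
```

Both figures remain well below the roughly one hundred or more operations per byte that modern accelerators need to stay compute-bound, but the multi-query variant closes much of the gap.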
Inference Efficiency: The optimisations lead to a significant improvement in inference speed (up to 7x faster), allowing the use of a larger decoder without a substantial increase in computational cost.
Performance Gains: These architectural changes not only enhance inference efficiency but also improve model performance across various knowledge-intensive tasks, as demonstrated in benchmark results for question-answering datasets.
The authors propose FiDO (Fusion-in-Decoder Optimized), an evolved architecture designed to address the inefficiencies identified in FiD:
Reducing Cross-Attention Layers: By removing cross-attention from most decoder layers, FiDO significantly cuts computational cost and memory traffic while maintaining performance. This adjustment directly targets the decoder's memory bandwidth constraint (illustrated, together with multi-query attention, in the sketch after the next point).
Introducing Multi-Query Attention: Replacing multi-head attention with multi-query attention, in which all query heads share a single key and value head, further shrinks the key/value data the decoder must stream from memory at each step. This makes inference substantially faster and brings inference time back in line with the FLOP split between encoder and decoder, as sketched below.
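A minimal NumPy sketch of the two decoder changes follows. The function name, shapes, and the every-6th-layer cross-attention pattern are our illustrative choices, not the exact configuration reported in the paper: all query heads share one key/value head, and cross-attention is kept only in a subset of decoder layers.

```python
import numpy as np

def multi_query_cross_attention(x, enc_mem, wq, wk, wv, wo, n_heads):
    """Multi-query cross-attention: n_heads query heads share one key/value head.

    x       : (n_t, d)    decoder hidden states
    enc_mem : (n_s, d)    concatenated encoder outputs from all retrieved passages
    wq, wo  : (d, d)      query / output projections
    wk, wv  : (d, d_h)    shared key / value projections, d_h = d // n_heads
    """
    n_t, d = x.shape
    d_h = d // n_heads
    q = (x @ wq).reshape(n_t, n_heads, d_h)            # per-head queries
    k = enc_mem @ wk                                   # (n_s, d_h), shared by all heads
    v = enc_mem @ wv                                   # (n_s, d_h), shared by all heads
    scores = np.einsum("thd,sd->ths", q, k) / np.sqrt(d_h)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)     # softmax over source tokens
    ctx = np.einsum("ths,sd->thd", weights, v).reshape(n_t, d)
    return ctx @ wo

# Hypothetical sparsity pattern: keep cross-attention only in every 6th decoder layer.
n_layers = 24
has_cross_attention = [(i + 1) % 6 == 0 for i in range(n_layers)]

# Smoke test with random weights (shapes only; the values are meaningless).
rng = np.random.default_rng(0)
d, n_heads, n_t, n_s = 64, 8, 4, 32
out = multi_query_cross_attention(
    rng.normal(size=(n_t, d)), rng.normal(size=(n_s, d)),
    rng.normal(size=(d, d)), rng.normal(size=(d, d // n_heads)),
    rng.normal(size=(d, d // n_heads)), rng.normal(size=(d, d)), n_heads)
print(out.shape, has_cross_attention)
```

Because the shared key/value tensors are n_heads times smaller than in multi-head attention, the per-step memory traffic of cross-attention shrinks accordingly, while removing cross-attention from most layers cuts it further.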
Decoder Scaling (DecXL): A notable aspect of FiDO is that it scales the decoder up aggressively while keeping the encoder comparatively small. Because the leaner decoder is now cheap at inference time, this rebalancing strengthens the decoder's ability to synthesise and reason over the retrieved information without a substantial increase in cost.
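A hypothetical configuration in this spirit might look like the following; the sizes and the sparse cross-attention pattern are illustrative placeholders, not the settings reported in the paper.

```python
# Hypothetical FiDO-style asymmetric configuration (illustrative sizes only,
# not the settings reported in the paper).
fido_like_config = {
    "encoder": {"layers": 12, "d_model": 768, "heads": 12},   # kept small: passages are encoded in parallel
    "decoder": {
        "layers": 24,
        "d_model": 2048,                                      # scaled-up decoder for synthesis and reasoning
        "heads": 32,
        "attention": "multi-query",                           # single shared key/value head
        "cross_attention_layers": [5, 11, 17, 23],            # sparse cross-attention (hypothetical pattern)
    },
}
```

The idea is that once the decoder is memory-efficient at inference, the extra parameters can go where the reasoning over retrieved passages happens without reintroducing the original bottleneck.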
Performance Metrics: The proposed FiDO architecture is up to 7x faster at inference while also showing strong performance gains across question-answering benchmarks such as Natural Questions, TriviaQA, and WebQuestions.
Inference Efficiency: The modifications lead to more efficient use of computational resources, particularly during inference, so that the time spent in each component of the model is proportional to the work it actually contributes.
FiDO represents a stride forward in optimising retrieval-augmented language models.
By addressing the architectural inefficiencies in FiD, the researchers not only enhance the model's inference speed and computational economy but also set a new standard for performance in knowledge-intensive tasks.
This work underscores the importance of architectural considerations in the development of AI systems, highlighting how thoughtful modifications can lead to substantial gains in efficiency and effectiveness.