Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation
Ofir Press, Noah A. Smith, Mike Lewis
Copyright Continuum Labs - 2023
The paper addresses a fundamental question in transformer-based language models: how can a model handle sequences at inference time that are longer than any it saw during training?
The authors show that the choice of position representation method plays a crucial role in enabling extrapolation. They introduce a new method called Attention with Linear Biases (ALiBi), which is simpler and more efficient than existing approaches.
The paper first examines the extrapolation capabilities of current position representation methods:
Sinusoidal Position Embeddings: Despite the theoretical ability to extrapolate, in practice, sinusoidal embeddings have very limited extrapolation capabilities.
Rotary Position Embeddings: While rotary embeddings improve upon sinusoidal embeddings, they still do not achieve satisfactory extrapolation results and come with computational costs.
T5 Bias: The T5 bias method leads to better extrapolation than sinusoidal and rotary embeddings but is computationally expensive.
The authors conclude that extrapolation ability depends heavily on the position embedding method, and current approaches do not extrapolate efficiently.
ALiBi is a simple and efficient method for enabling extrapolation in transformer-based language models. Instead of adding position embeddings to word embeddings, ALiBi biases the query-key attention scores with a penalty proportional to their distance.
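The biased attention computation can be written compactly. The notation here follows the ALiBi paper rather than the summary above: q_i is the i-th query, K holds the keys up to position i, and m is a fixed, head-specific slope.

```latex
\mathrm{softmax}\left( q_i K^{\top} + m \cdot \left[ -(i-1), \ldots, -2, -1, 0 \right] \right)
```

The rightmost entry (the key at the query's own position) receives no penalty, and each step further back subtracts another m from the score.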
Key points about ALiBi:
No position embeddings are added at any point in the network.
A static, non-learned bias is added to the query-key dot product.
The bias is a penalty that grows linearly with the distance between the query and key, so scores for more distant keys are reduced more.
Each attention head uses a different fixed slope for the linear bias, so heads penalize distance at different rates and can capture different ranges of dependencies.
ALiBi is easy to implement with minimal changes to existing transformer code.
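The key points above can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation: the function names are hypothetical, the causal (left-to-right) setting is assumed, and the slope schedule is the geometric sequence described in the ALiBi paper (starting at 2^(-8/n) for n heads, which gives 1/2, 1/4, ..., 1/256 for 8 heads), restricted here to a power-of-two number of heads.

```python
import numpy as np


def alibi_slopes(n_heads: int) -> np.ndarray:
    """Fixed per-head slopes: a geometric sequence starting at 2^(-8/n_heads).

    Assumption for this sketch: n_heads is a power of two.
    """
    if n_heads <= 0 or n_heads & (n_heads - 1):
        raise ValueError("this sketch assumes n_heads is a power of two")
    start = 2.0 ** (-8.0 / n_heads)
    return start ** np.arange(1, n_heads + 1)


def alibi_bias(n_heads: int, seq_len: int) -> np.ndarray:
    """Static bias of shape (n_heads, seq_len, seq_len), added to the query-key scores."""
    pos = np.arange(seq_len)
    # distance[i, j] = i - j for keys at or before query i; 0 for future keys,
    # which the causal mask removes anyway.
    distance = np.maximum(pos[:, None] - pos[None, :], 0)
    return -alibi_slopes(n_heads)[:, None, None] * distance


def causal_attention_weights(scores: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Softmax over keys of (scores + bias) with future positions masked out.

    scores and bias both have shape (n_heads, seq_len, seq_len).
    """
    seq_len = scores.shape[-1]
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    logits = np.where(future, -np.inf, scores + bias)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


# Usage: the bias has no learned parameters and no dependence on a maximum
# length, so the same function serves a longer sequence at inference time.
train_bias = alibi_bias(n_heads=8, seq_len=512)
longer_bias = alibi_bias(n_heads=8, seq_len=3072)
```

Because the bias depends only on the query-key distance, the top-left 512 by 512 corner of `longer_bias` is identical to `train_bias`; nothing has to be retrained or resized to run on the longer input.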
The authors evaluate ALiBi on various datasets and model sizes, demonstrating its effectiveness and efficiency.
WikiText-103: ALiBi models outperform sinusoidal models even when trained on fewer tokens. Models trained on shorter sequences (e.g., 512 tokens) can extrapolate to much longer sequences (e.g., 3072 tokens) and achieve better perplexity than sinusoidal models trained on the longer sequences.
Toronto BookCorpus: ALiBi's performance gains transfer to a different domain (books) without modifying the hyperparameters, demonstrating its robustness.
Large-scale experiments: With a 1.3B parameter model trained on a much larger dataset (461 GB), ALiBi achieves perplexity similar to that of sinusoidal models while running faster and using less memory.
In summary, the Attention with Linear Biases (ALiBi) paper introduces a simple and efficient method for enabling extrapolation in transformer-based language models. By biasing query-key attention scores with a distance-dependent penalty, ALiBi allows models trained on shorter sequences to extrapolate to longer sequences during inference, outperforming existing position representation methods. The method is easy to implement, computationally efficient, and generalizes well across different domains and model sizes.
Sinusoidal position embeddings are a method used in transformer-based models to encode the position or order of tokens in a sequence. In transformers, positional information is essential because the models do not have recurrent connections like RNNs, which inherently capture the order of input tokens.
Sinusoidal position embeddings assign each position in a sequence an embedding vector computed from sine and cosine functions of different frequencies. The vector encodes the absolute position directly; because the embedding at a fixed offset can be expressed as a linear function of the embedding at the current position, relative positions are also recoverable in principle. The sinusoidal functions give different positions distinct values while keeping the embeddings of nearby positions similar.
The sinusoidal position embeddings are added to the input embeddings of tokens before being passed through the transformer layers. By incorporating the positional information into the model's input, transformers can learn to attend to different positions in the sequence and capture dependencies based on their relative positions.
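As a concrete illustration, the sketch below computes sinusoidal embeddings with NumPy and adds them to token embeddings. It assumes the standard formulation from the original transformer paper (sine on even dimensions, cosine on odd dimensions, base 10000); the function name and the random token embeddings are hypothetical stand-ins.

```python
import numpy as np


def sinusoidal_position_embeddings(seq_len: int, d_model: int) -> np.ndarray:
    """Return a (seq_len, d_model) array of fixed sinusoidal position embeddings."""
    if d_model % 2:
        raise ValueError("this sketch assumes an even d_model")
    positions = np.arange(seq_len)[:, None]                        # (seq_len, 1)
    # One frequency per pair of dimensions, decreasing geometrically.
    freqs = 10000.0 ** (-np.arange(0, d_model, 2) / d_model)       # (d_model / 2,)
    angles = positions * freqs                                     # (seq_len, d_model / 2)
    embeddings = np.zeros((seq_len, d_model))
    embeddings[:, 0::2] = np.sin(angles)  # even dimensions
    embeddings[:, 1::2] = np.cos(angles)  # odd dimensions
    return embeddings


# Usage: add the position embeddings to (hypothetical) token embeddings
# before the first transformer layer.
rng = np.random.default_rng(0)
token_embeddings = rng.normal(size=(10, 16))
model_input = token_embeddings + sinusoidal_position_embeddings(10, 16)
```

The embeddings are not learned, so they can be computed for any position, but as the paper's experiments show, being computable for longer inputs does not mean the model extrapolates well to them.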
However, as noted above, sinusoidal position embeddings extrapolate poorly, i.e., models that use them perform badly on sequences longer than those encountered during training. Alternative methods, such as Attention with Linear Biases (ALiBi), have been proposed to address this limitation and improve the ability of models to extrapolate to longer sequences.
The term "sinusoidal" is derived from the word "sine," which refers to the trigonometric function sine. In mathematics, a sinusoidal function or waveform follows the shape of a sine or cosine curve.
The sine function, denoted as sin(x), is a periodic function that oscillates between -1 and 1. For an acute angle of a right triangle, it gives the ratio of the opposite side to the hypotenuse. The function extends to all real numbers and is used to model natural phenomena that exhibit periodic behavior.
The use of sinusoidal functions in the context of position embeddings in transformers is inspired by the need to capture the notion of position and sequence order in a continuous and smooth manner. Sinusoidal functions offer several desirable properties for this purpose:
Periodicity: Each sinusoid repeats its values over a fixed interval (its wavelength), and the embedding uses a different wavelength for each pair of dimensions, so a position is represented by a combination of cyclical patterns.
Smoothness: Sinusoidal functions exhibit a smooth, continuous curve, which helps in preserving some degree of similarity between neighboring positions.
Distinctness: A single sinusoid repeats, so it cannot identify a position on its own. Combining sinusoids of many different frequencies across the embedding dimensions gives each position within the practical range a distinct embedding vector.
By using sinusoidal functions to encode position information, transformers can leverage the periodic nature of these functions to model the sequential relationships between tokens in a sequence. This enables the model to capture positional dependencies and effectively process sequential data.
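The three properties listed above can be checked numerically. The sketch below is an illustration under the same assumptions as the standard sinusoidal formulation (base 10000, sine on even and cosine on odd dimensions); the helper names and the chosen sizes (512 positions, 64 dimensions) are arbitrary.

```python
import numpy as np


def sinusoidal_pe(seq_len: int, d_model: int) -> np.ndarray:
    """Standard sinusoidal position embeddings, shape (seq_len, d_model)."""
    pos = np.arange(seq_len)[:, None]
    freqs = 10000.0 ** (-np.arange(0, d_model, 2) / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(pos * freqs)
    pe[:, 1::2] = np.cos(pos * freqs)
    return pe


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


d_model = 64
pe = sinusoidal_pe(512, d_model)

# Periodicity: each pair of dimensions has its own wavelength, from 2*pi
# for the fastest sinusoid up to several tens of thousands of positions.
wavelengths = 2 * np.pi * 10000.0 ** (np.arange(0, d_model, 2) / d_model)

# Smoothness: a neighboring position is more similar than a distant one.
sim_near = cosine_similarity(pe[100], pe[101])
sim_far = cosine_similarity(pe[100], pe[150])

# Distinctness: no two positions share the same embedding vector.
all_distinct = len({row.tobytes() for row in pe}) == len(pe)

print(f"shortest wavelength: {wavelengths[0]:.2f}, longest: {wavelengths[-1]:.0f}")
print(f"similarity at offset 1: {sim_near:.3f}, at offset 50: {sim_far:.3f}")
print(f"all positions distinct: {all_distinct}")
```

Note that the similarity between two positions does not fall off monotonically with distance, since the higher-frequency dimensions keep oscillating; the comparison above only contrasts one near offset with one far offset.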