Prefix-Tuning: Optimizing Continuous Prompts for Generation
Copyright Continuum Labs - 2023
This highly cited January 2021 paper introduced a new technique for efficiently fine-tuning language models (LMs) called prefix-tuning.
This method addressed the challenges of efficiently adapting LMs to specific tasks while maintaining their generalisation capabilities and minimising the storage requirements for task-specific parameters.
Prefix tuning adapts pre-trained language models to specific tasks without modifying the original model's weights.
Prefix tuning draws inspiration from the concept of prompting, where task instructions and examples are prepended to the input to steer the LM's generation. However, instead of using discrete tokens, prefix tuning uses a sequence of continuous vectors.
Prefix-tuning involves prepending a sequence of continuous task-specific vectors, called a prefix, to the input of the LM.
The Transformer can attend to these prefix vectors as if they were a sequence of "virtual tokens". Unlike prompting, the prefix vectors do not correspond to real tokens but are learned during training.
In prefix tuning, we create a tensor called a "soft prompt" for each transformer block in the model.
This soft prompt is a set of learnable parameters that are specific to the task we want to adapt the model for.
Before it is used, the soft prompt is passed through a set of fully connected layers.
These layers transform the soft prompt into a suitable representation that can be combined with the main input to the transformer block.
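The transformation through fully connected layers can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's implementation: the two-layer shape, the tanh nonlinearity, the sizes, and names such as `raw_prompt` and `reparameterize` are all assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not values from the paper).
prefix_len, d_small, d_hidden, d_model = 10, 16, 32, 64

# Small trainable soft prompt (hypothetical name).
raw_prompt = rng.normal(size=(prefix_len, d_small))

# Trainable fully connected layers that map it to the model dimension.
W1 = rng.normal(scale=0.1, size=(d_small, d_hidden))
b1 = np.zeros(d_hidden)
W2 = rng.normal(scale=0.1, size=(d_hidden, d_model))
b2 = np.zeros(d_model)

def reparameterize(prompt):
    """Pass each soft-prompt row through a two-layer MLP with tanh."""
    hidden = np.tanh(prompt @ W1 + b1)
    return hidden @ W2 + b2

soft_prompt = reparameterize(raw_prompt)
print(soft_prompt.shape)  # (10, 64)
```

The output has one row per virtual token and matches the model dimension, so it can be combined with the block's main input.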
The transformed soft prompt is then concatenated with the main input to the transformer block.
This concatenation happens along the sequence length dimension, meaning the soft prompt is added as additional tokens at the beginning of the input sequence.
The modified input, which now includes the soft prompt, is passed through the standard transformer block operations.
These operations include self-attention, layer normalisation, and feed-forward neural network layers, along with residual connections.
The transformer block processes the input as usual, but now it also takes into account the information provided by the soft prompt.
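The steps above can be sketched as a single simplified block. This is a toy, single-head version with random stand-ins for the frozen weights and no attention mask. Trimming the prefix positions from the output (so that the next block can prepend its own soft prompt) is an assumption of this sketch, as are all names and sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, prefix_len, seq_len = 8, 3, 5  # illustrative sizes

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Frozen block weights (random stand-ins for pre-trained values).
Wq, Wk, Wv, Wo = (rng.normal(scale=0.3, size=(d_model, d_model)) for _ in range(4))
W_up = rng.normal(scale=0.3, size=(d_model, 4 * d_model))
W_down = rng.normal(scale=0.3, size=(4 * d_model, d_model))

def block_with_prefix(x, soft_prompt):
    # Prepend the soft prompt along the sequence dimension.
    h = np.concatenate([soft_prompt, x], axis=0)
    # Self-attention, then residual connection and layer normalisation.
    q, k, v = h @ Wq, h @ Wk, h @ Wv
    attn = softmax(q @ k.T / np.sqrt(d_model)) @ v
    h = layer_norm(h + attn @ Wo)
    # Feed-forward network, then residual connection and layer normalisation.
    ff = np.maximum(h @ W_up, 0.0) @ W_down
    h = layer_norm(h + ff)
    # Assumption of this sketch: return only the original input positions.
    return h[soft_prompt.shape[0]:]

x = rng.normal(size=(seq_len, d_model))
soft_prompt = rng.normal(size=(prefix_len, d_model))
out = block_with_prefix(x, soft_prompt)
print(out.shape)  # (5, 8)
```

Calling the block with a different `soft_prompt` changes `out` even though `x` and the block weights stay the same, which is the effect prefix tuning relies on.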
During training, only the soft prompts are updated, while the pre-trained model's weights remain frozen.
The model learns to adapt to the specific task by adjusting the soft prompts based on the task-specific training data.
By keeping the original model's weights unchanged, prefix tuning allows for efficient adaptation without the need for fine-tuning the entire model.
The key idea behind prefix tuning is that by adding task-specific soft prompts to each transformer block, the model can learn to condition its behavior based on the prompts.
The soft prompts act as a "prefix" that guides the model's attention and computation towards the relevant information for the task at hand.
The prefix is a sequence of continuous vectors that are prepended to the input sequence.
The length of the prefix is a hyperparameter that you can choose based on the complexity of the target personality and the available computational resources. Common prefix lengths range from 10 to 50 tokens.
The prefix is initialised as a trainable matrix $P$ of size $k \times d$, where $k$ is the number of prefix tokens and $d$ is the size of the model's word embeddings.
Each row of the prefix matrix corresponds to a prefix token, and the values in that row represent the embedding of that token.
The prefix matrix is randomly initialised or initialised using the activations of real words that are relevant to the target personality. Initialising with relevant words can provide a good starting point for the prefix and potentially speed up convergence during training.
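Both initialisation options can be sketched as follows. The toy vocabulary, the random embedding table, and the function names are hypothetical. The example uses word embeddings as the simplest form of "activations of real words".

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy vocabulary and a random stand-in for the frozen embedding table.
vocab = {"friendly": 0, "cheerful": 1, "helpful": 2, "the": 3, "a": 4}
d_model = 8
embedding_table = rng.normal(size=(len(vocab), d_model))

def init_random(prefix_len):
    """Random initialisation with a small standard deviation."""
    return rng.normal(scale=0.02, size=(prefix_len, d_model))

def init_from_words(words):
    """Start each prefix row from the embedding of a relevant real word."""
    ids = [vocab[w] for w in words]
    # Explicit copy so that training the prefix never alters the table.
    return embedding_table[ids].copy()

prefix_random = init_random(3)
prefix_words = init_from_words(["friendly", "cheerful", "helpful"])
print(prefix_random.shape, prefix_words.shape)  # (3, 8) (3, 8)
```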
In the Transformer architecture, the input sequence is typically represented as a matrix of word embeddings, where each row corresponds to a token in the sequence.
To incorporate the prefix, you concatenate the prefix matrix with the input embeddings matrix along the sequence dimension (usually axis 1). This results in a new input matrix $[P; X]$, where $P$ is the prefix matrix and $X$ is the original input embeddings matrix.
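A short NumPy sketch of this concatenation for a batch of inputs, with a prefix matrix `P` and input embeddings `X`. The sizes and variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n, k, d = 2, 6, 4, 8  # batch size, input length, prefix length, embedding size

X = rng.normal(size=(batch, n, d))  # input embeddings
P = rng.normal(size=(k, d))         # trainable prefix, shared across the batch

# Repeat the prefix for every example, then join along the sequence axis.
P_batched = np.broadcast_to(P, (batch, k, d))
Z = np.concatenate([P_batched, X], axis=1)
print(Z.shape)  # (2, 10, 8)
```

The first `k` positions of every example hold the prefix, and the remaining `n` positions hold the original input embeddings.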
During the forward pass, the concatenated matrix is passed through the Transformer layers, which include self-attention and feed-forward layers.
The self-attention mechanism in the Transformer layers allows the input tokens to attend to the prefix tokens, so the prefix influences the representations of the input tokens and effectively steers the model's behavior.
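The effect of the prefix on the input positions can be checked with a toy causal self-attention. In a left-to-right LM each position only sees earlier positions, so influence flows from the prefix to the input tokens. The single head with identity projections is a simplifying assumption, and all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
prefix_len, n, d = 2, 3, 4  # illustrative sizes

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def causal_self_attention(z):
    """Single-head attention with identity projections, for illustration only."""
    scores = z @ z.T / np.sqrt(z.shape[-1])
    # True above the diagonal marks future positions, which are masked out.
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ z

X = rng.normal(size=(n, d))
P_a = rng.normal(size=(prefix_len, d))
P_b = rng.normal(size=(prefix_len, d))

out_a = causal_self_attention(np.concatenate([P_a, X]))
out_b = causal_self_attention(np.concatenate([P_b, X]))

# Same input, different prefix: the outputs at the input positions differ.
input_changed = not np.allclose(out_a[prefix_len:], out_b[prefix_len:])

# Same prefix, different input: the outputs at the prefix positions are unchanged.
X_other = rng.normal(size=(n, d))
out_c = causal_self_attention(np.concatenate([P_a, X_other]))
prefix_unchanged = np.allclose(out_a[:prefix_len], out_c[:prefix_len])

print(input_changed, prefix_unchanged)  # True True
```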
Importantly, during training, only the prefix matrix is updated, while the pre-trained model's parameters (i.e., the weight matrices in the Transformer layers) remain frozen. This ensures that the prefix adapts to the target personality while preserving the general language understanding captured by the pre-trained model.
To train the prefix, you use a prepared dataset that consists of input-output pairs, where the input is a prompt or context and the output is the corresponding response that reflects the desired personality.
During training, you feed the input sequence through the modified model architecture, which includes the prefix matrix concatenated with the input embeddings.
The model generates a probability distribution over the vocabulary for each position in the output sequence, and you compute a language modeling loss (e.g., cross-entropy loss) between the predicted probabilities and the true output tokens.
The gradients of the loss with respect to the prefix matrix are computed using backpropagation, and the prefix matrix is updated using an optimisation algorithm like Adam.
The pre-trained model's parameters remain fixed during training, so only the prefix matrix is updated to minimise the language modeling loss.
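The training procedure can be illustrated with a deliberately tiny example. The "model" here is a frozen linear projection applied to the mean of the concatenated sequence, which is only a stand-in for a real Transformer. Plain gradient descent is used in place of Adam to keep the sketch short, and the gradient is written out by hand. All names and sizes are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
k, n, d, vocab_size = 4, 6, 8, 5  # prefix length, input length, embedding size, vocabulary

W_frozen = rng.normal(size=(d, vocab_size))  # stand-in for the frozen pre-trained model
W_before = W_frozen.copy()
X = rng.normal(size=(n, d))                  # input embeddings for one example
target = 2                                   # index of the true output token
P = rng.normal(scale=0.02, size=(k, d))      # trainable prefix

def predict(P):
    Z = np.concatenate([P, X], axis=0)       # prepend the prefix to the input
    logits = Z.mean(axis=0) @ W_frozen
    e = np.exp(logits - logits.max())
    return e / e.sum()                       # probabilities over the vocabulary

def loss_and_grad(P):
    probs = predict(P)
    loss = -np.log(probs[target])            # cross-entropy loss
    dlogits = probs.copy()
    dlogits[target] -= 1.0                   # gradient of the loss w.r.t. the logits
    dpooled = W_frozen @ dlogits             # back through the frozen projection
    # Mean pooling spreads the gradient equally over all k + n rows.
    dP = np.tile(dpooled / (k + n), (k, 1))
    return loss, dP

lr = 0.5
losses = []
for step in range(50):
    loss, dP = loss_and_grad(P)
    losses.append(loss)
    P -= lr * dP                             # only the prefix is updated

print(losses[-1] < losses[0])  # True
```

`W_frozen` is never written to during the loop, so the reduction in loss comes entirely from the prefix.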
You can experiment with different hyperparameters such as the learning rate, batch size, and number of training epochs to find the optimal configuration that achieves the best performance on a validation set.
By training the prefix matrix while keeping the pre-trained model's parameters frozen, you allow the prefix to adapt to the target personality while leveraging the general language understanding captured by the pre-trained model.
The prefix acts as a "soft prompt" that steers the model's behavior towards generating responses that align with the desired personality.
Prefix tuning allows for efficient adaptation of pre-trained models to new tasks without modifying the original model's weights.
It requires far fewer trainable parameters than fine-tuning the entire model, making it cheaper to train and to store.
It can be applied to any pre-trained transformer-based model without the need for task-specific architectures.
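As a rough, illustrative calculation of the parameter saving, the snippet below counts the trainable values for a per-layer prefix and compares them with the size of the full model. The model dimensions are hypothetical round numbers, and the factor of two assumes each layer stores a separate key and value vector per prefix position.

```python
def prefix_param_count(prefix_len, num_layers, d_model, vectors_per_layer=2):
    """Trainable values when every layer gets its own prefix.

    vectors_per_layer=2 assumes one key and one value vector per position.
    """
    return prefix_len * num_layers * vectors_per_layer * d_model

# Hypothetical model: 24 layers, hidden size 1024, 350M parameters in total.
total_model_params = 350_000_000
trainable = prefix_param_count(prefix_len=10, num_layers=24, d_model=1024)
fraction = trainable / total_model_params

print(trainable)          # 491520
print(f"{fraction:.2%}")  # 0.14%
```

Under these assumptions, each additional task costs roughly half a million stored values rather than a full copy of the model.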
The diagram below demonstrates:
Prefix-tuning allows for the independent training of tasks, enabling scalable personalisation without data cross-contamination.
Each user's data can be isolated, and a personalised prefix can be trained for each user, ensuring privacy and modularity. The independence of tasks also enables efficient batching across users and the creation of ensembles of multiple prefixes trained on the same task.
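Batching across users can be sketched by giving each request its own prefix before stacking the requests into one batch. The user names, the random prefixes, and the `build_batch` helper are hypothetical, and the sketch assumes all inputs in a batch have the same length.

```python
import numpy as np

rng = np.random.default_rng(0)
k, n, d = 3, 5, 8  # prefix length, input length, embedding size

# One independently trained prefix per user (random stand-ins here).
user_prefixes = {
    "alice": rng.normal(size=(k, d)),
    "bob": rng.normal(size=(k, d)),
}

def build_batch(requests):
    """requests: list of (user_id, input_embeddings) pairs of equal length."""
    rows = [
        np.concatenate([user_prefixes[user], x], axis=0)
        for user, x in requests
    ]
    return np.stack(rows)

requests = [
    ("alice", rng.normal(size=(n, d))),
    ("bob", rng.normal(size=(n, d))),
]
batch = build_batch(requests)
print(batch.shape)  # (2, 8, 8)
```

Because the shared model is frozen, the same forward pass can serve both rows, while each row carries only its own user's prefix.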
The paper demonstrates the effectiveness of prefix-tuning through extensive experiments on various natural language generation tasks, such as table-to-text generation and summarisation.
The results show that prefix-tuning outperforms other lightweight fine-tuning methods, such as adapter-tuning, while using substantially fewer parameters. It achieves performance comparable to full fine-tuning in the full-data setting, and it outperforms full fine-tuning in low-data regimes and when generalising to unseen topics.
The authors also explore the impact of prefix length on the model's performance, revealing that there is an optimal prefix length for each task.
Increasing the prefix length up to a certain threshold improves performance, but further increases lead to diminishing returns and potential overfitting.
Furthermore, the paper compares prefix-tuning with an embedding-only approach, where only the embeddings of the virtual tokens are optimised.
The results demonstrate that the embedding-only approach lacks the expressiveness necessary to achieve optimal performance, highlighting the importance of optimising the prefix vectors across all layers of the LM.
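One way to see the difference in expressiveness is to compare the shapes of what each approach is allowed to train. The sizes below are hypothetical, and the key/value pair per layer is an assumption of this sketch.

```python
import math

prefix_len, num_layers, d_model = 10, 24, 1024  # illustrative sizes

# Embedding-only: one trainable vector per virtual token, at the input only.
embedding_only_shape = (prefix_len, d_model)

# All layers: a key and a value vector per virtual token in every layer.
all_layers_shape = (num_layers, 2, prefix_len, d_model)

embedding_only_count = math.prod(embedding_only_shape)
all_layers_count = math.prod(all_layers_shape)

print(embedding_only_count)                      # 10240
print(all_layers_count)                          # 491520
print(all_layers_count // embedding_only_count)  # 48
```

In the embedding-only case, everything above the input layer is computed by the frozen model, whereas the all-layer prefix can be adjusted directly at every depth.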
The discussion section of the paper emphasises the potential of prefix-tuning for real-world applications, particularly in scenarios requiring personalisation, privacy, efficiency, and scalability.
The modularity and independence of tasks make prefix-tuning suitable for enterprise-level applications where customer-specific interactions and computational efficiency are crucial.
In conclusion, the paper introduces a powerful and efficient method for fine-tuning large language models. By optimising continuous prompts in the form of prefix vectors, prefix-tuning achieves strong performance while significantly reducing the storage requirements for task-specific adaptations.
The technique's modularity, privacy-preserving nature, and scalability make it particularly suitable for real-world applications and enterprise-level deployments.