DSPy: In-Context Learning for Extreme Multi-Label Classification
Last updated
Copyright Continuum Labs - 2023
This January 2024 paper addresses the challenge of solving multi-label classification problems with thousands of classes using in-context learning alone.
Language models (LMs) often lack prior knowledge about the specific classes, and demonstrating every class in a prompt is impractical.
The authors propose a general program named Infer–Retrieve–Rank (IReRa) to efficiently tackle such problems.
Implemented using the DSPy programming model, IReRa defines multi-step interactions between LMs and retrievers. DSPy optimizers tune the program towards specific datasets using only a few labeled examples.
The proposed solution achieves state-of-the-art results across various benchmarks without requiring finetuning, prompt engineering, or extensive labeled data. The program is highly adaptable to new tasks and datasets, demonstrating competitive performance even in benchmarks with vastly different characteristics.
In the context of this paper, "classes" refer to the different categories or labels that an item can belong to in a multi-label classification task.
Here’s a simplified explanation:
Multi-label Classification: This is a type of problem where each item can be assigned more than one label or category. For example, an email might be classified as both "important" and "work-related."
Classes: These are the possible labels or categories that an item can be assigned to. In this case, there can be upwards of 10,000 different classes. For example, if you were classifying job descriptions, classes might include labels like "software development," "data analysis," "project management," etc.
Extreme Multi-label Classification (XMC): When there are a very large number of possible classes, it becomes an extreme classification problem. Handling such a large number of classes is challenging because a language model needs to understand and distinguish between all these potential categories.
Lack of Prior Knowledge: Language models might not have prior knowledge about the specific classes, especially when there are thousands of them.
Infeasibility of Demonstration: It’s generally impractical to demonstrate every class in a prompt because of the sheer number of classes.
Complex Configuration: Existing methods often require complex configurations with multiple LM calls, prompts, and hyperparameters, making it difficult to apply them to new datasets or LMs.
To address these challenges, the authors propose a method called Infer-Retrieve-Rank (IReRa), which combines the following steps and design choices:
Infer: An in-context learning module processes the input and predicts a set of applicable terms (queries).
Retrieve: These predicted terms are then mapped onto the actual label space using a frozen retriever.
Rank: A second in-context learning module re-ranks the retrieved labels.
Minimal Prompt: A minimal prompt is used to bootstrap the process, making it easier to configure and adapt to new tasks.
Zero-shot Teacher LM: This model generates initial demonstrations to optimise the few-shot Student LM.
Efficiency: The approach needs only about 50 labeled examples to achieve state-of-the-art results.
DSPy Programming Model: The DSPy model allows for the separate specification and optimisation of the program, making it flexible and generalisable.
No Finetuning Required: The method does not require finetuning the LMs or the retriever for different datasets.
Ease of Adaptation: Adapting to new datasets involves simple steps like writing a new prompt and configuring the LMs.
State-of-the-art Performance: The proposed method achieves state-of-the-art results in various benchmarks with minimal labeled data and no prompt engineering.
The related work section outlines previous approaches and methods used to tackle extreme multi-label classification (XMC) problems and compares them with the proposed Infer-Retrieve-Rank method.
Finetuning Specialized Retrievers or Binary Classifiers:
Specialized Retrievers: These are finetuned over the label space to efficiently retrieve relevant labels.
Binary Classifiers: One binary classifier is finetuned per class to decide whether an input belongs to a specific class or not.
Drawbacks: Both approaches require a significant amount of labeled data, as each class needs at least a few labeled examples to train effectively.
Distant Supervision:
Purpose: Used to avoid manual data labeling.
Method: Employs heuristics or external resources to automatically label data. This approach can provide initial labels without manual annotation.
Synthetic Data Bootstrapping:
Method: Large Language Models (LLMs) are used to generate synthetic data to augment the training dataset.
Examples: Decorte et al. (2023), Clavié and Soulié (2023), and De Raedt et al. (2023) used this method to bootstrap synthetic data.
Finetuning Retrievers on Adjacent Problems:
Method: Retrievers are finetuned on related problems where labeled data is available.
Example: Remy et al. (2022) used this approach to leverage available data from adjacent problems to improve retriever performance.
Reranking with Additional LLM Calls:
Method: An additional LLM call is made at inference time to rerank a list of candidate labels, aiming to boost performance.
Example: Clavié and Soulié (2023) employed this technique to enhance label accuracy.
Inference-time Multiple LLM Calls:
Method: Zhu and Zamani (2023) utilized multiple GPT-3.5 calls combined with retrieval to bootstrap synthetic prompts per input, infer labels, and rerank them.
Evaluation: This approach was evaluated on two recommendation tasks where the input and output documents were of the same type.
The authors compare their Infer-Retrieve-Rank program with the aforementioned methods, highlighting several key differences and advantages:
Efficiency:
Minimal Data Requirement: Infer-Retrieve-Rank can achieve state-of-the-art performance using only approximately 50 labeled examples, making it much more data-efficient compared to other methods that require extensive labeled data.
Few LLM Calls: Unlike methods requiring numerous LLM calls per input, Infer-Retrieve-Rank makes one call to its Infer module and one to its Rank module per input, enhancing efficiency.
No Finetuning Required:
The proposed method does not rely on finetuning the LMs or retrievers, simplifying the development and deployment process.
Modular and Declarative Program Logic:
Flexibility: The program logic is defined in a modular and declarative manner, allowing it to be seamlessly applied to different benchmarks with a minimal seed-prompt.
Automatic Optimization: The DSPy programming model handles optimisation automatically, significantly reducing the need for iterative prompt engineering. This optimisation can be completed in as little as ten minutes.
Configurable Components:
Adaptability: The choice of LMs and retrievers is configurable, so the program stays relevant and can improve as stronger components become available.
Single Seed-prompt: The method requires at most one seed-prompt per task and in-context module, simplifying the setup process.
The provided code block implements the Infer-Retrieve-Rank program using the DSPy framework.
This program is designed to tackle extreme multi-label classification tasks, which involve assigning multiple labels to a given input from a very large set of possible labels.
The Infer-Retrieve-Rank approach uses language models (LMs) and retrievers in a modular and efficient manner to predict, retrieve, and rerank labels based on the input data.
Below is a detailed breakdown and explanation of each part of the code:
Class Definition:
The code defines a new class named InferRetrieveRank that inherits from dspy.Module. This class represents the Infer-Retrieve-Rank program.
Initialization Method:
The __init__ method initializes the object.
The infer and rank attributes are created with dspy.ChainOfThought, which takes infer_sig and rank_sig as signatures. A signature declares the task instruction and the input and output fields of an in-context learning module; wrapping it in ChainOfThought makes the LM produce its reasoning before filling in the output fields.
The retrieve attribute is set to the retriever module passed in as an argument.
Forward Method Definition:
The forward method holds the core logic of the Infer-Retrieve-Rank program. It takes a text input and returns a prediction.
Inference Step:
The infer module processes the input text and generates predictions (preds). Accessing completions.labels collects the labels output field from the module's completions.
Parsing LM Output:
The raw predictions are parsed into a clean list of label strings that serve as retrieval queries.
Retrieval Step:
The parsed predictions (preds) are used as queries to the retrieve module, which returns a list of labels from the label space ranked by their similarity to the queries.
Reranking Step:
The retrieved labels are reranked by the rank module, which takes both the original input text and the retrieved labels to produce the final ranked list.
Return Prediction:
The final ranked labels are returned as a dspy.Prediction object.
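As a rough, dependency-free illustration of the same control flow, the sketch below replaces the DSPy ChainOfThought modules and the frozen retriever with plain Python callables. Everything here is hypothetical: Prediction is a stand-in for dspy.Prediction, extract_labels is an assumed parser (the source does not show the authors' parsing logic), and the toy_* functions fake the LMs and the retriever. It is not the authors' implementation.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Prediction:
    """Hypothetical stand-in for dspy.Prediction: holds the final ranked labels."""
    labels: List[str]


def extract_labels(completions: List[str]) -> List[str]:
    """Hypothetical parser: split comma-separated completions into cleaned,
    de-duplicated label queries, preserving first-seen order."""
    queries: List[str] = []
    for completion in completions:
        for part in completion.split(","):
            label = part.strip().strip(".").lower()
            if label and label not in queries:
                queries.append(label)
    return queries


class InferRetrieveRank:
    """Dependency-free sketch of the Infer -> Retrieve -> Rank control flow."""

    def __init__(
        self,
        infer: Callable[[str], List[str]],            # text -> raw LM completions
        retrieve: Callable[[List[str]], List[str]],   # queries -> candidate labels
        rank: Callable[[str, List[str]], List[str]],  # (text, candidates) -> ranked labels
    ) -> None:
        self.infer = infer
        self.retrieve = retrieve
        self.rank = rank

    def forward(self, text: str) -> Prediction:
        raw = self.infer(text)                # Infer: free-text label guesses
        queries = extract_labels(raw)         # parse LM output into queries
        candidates = self.retrieve(queries)   # Retrieve: ground queries in the label space
        labels = self.rank(text, candidates)  # Rank: rerank candidates given the input
        return Prediction(labels=labels)


# Toy usage with stand-in components (no LM or embedding model involved).
ONTOLOGY = ["nausea", "headache", "rash", "drug-induced liver injury"]


def toy_infer(text: str) -> List[str]:
    return ["Nausea, liver injury", "headache."]  # two fake LM completions


def toy_retrieve(queries: List[str]) -> List[str]:
    # Keep ontology labels containing any query word, in ontology order.
    words = {w for q in queries for w in q.split()}
    return [label for label in ONTOLOGY if any(w in label for w in words)]


def toy_rank(text: str, candidates: List[str]) -> List[str]:
    return candidates[:10]  # a real Rank LM would reorder; here we only cap at 10


pipeline = InferRetrieveRank(toy_infer, toy_retrieve, toy_rank)
print(pipeline.forward("Patient reported nausea and headache.").labels)
# ['nausea', 'headache', 'drug-induced liver injury']
```

Injecting the three stages as callables mirrors the design point made above: the LMs and the retriever are configuration, while the program logic stays fixed.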
The provided code snippets define seed-prompts for the Infer and Rank modules for different datasets using the DSPy framework.
These prompts are organised using the DSPy Signature abstraction, which specifies the structure and behavior of each in-context learning module.
The BioDEX Infer Module seed-prompt is constructed to address the following key challenges:
Identification of Adverse Drug Reactions
The primary goal is to accurately identify adverse drug reactions mentioned in medical article snippets. By clearly defining the input and output fields, the module is guided to focus on extracting these reactions from the provided text.
Structured Data Processing
The use of dspy.InputField and dspy.OutputField ensures that the data is processed in a structured manner. This structuring helps maintain consistency and accuracy in identifying and formatting the adverse drug reactions.
Efficiency and Accuracy
By providing a clear task description and structured input/output fields, the seed-prompt ensures that the Infer module operates efficiently. The module can quickly process the input text and accurately identify the relevant adverse drug reactions without requiring extensive manual intervention.
Adaptability
The modular approach facilitated by the DSPy framework allows for easy adaptation of the Infer module to different datasets or slightly varied tasks. By changing the input and output field definitions, the module can be tailored to new requirements, demonstrating flexibility and scalability.
Class Definition
Defines a new class BiodexInferSignature that inherits from dspy.Signature. This class specifies the signature for the Infer module on the BioDEX dataset.
Docstring
Provides a task description. This docstring explains the task: given a snippet from a medical article, the module should identify and return the adverse drug reactions affecting the patient.
Input Field
Defines an input field with the prefix "Article:". This indicates that the input to the module will be a snippet from a medical article.
Output Field
Defines an output field with the prefix "Reactions:". The description specifies that the output should be a list of comma-separated adverse drug reactions.
The seed-prompt for the BioDEX Infer Module is designed to streamline the identification of adverse drug reactions from medical article snippets.
Using the DSPy framework, this seed-prompt employs a structured and declarative approach to define the behavior of the Infer module, ensuring efficient and accurate performance within the BioDEX dataset.
By explicitly outlining the input and output fields, it facilitates a clear and consistent processing pipeline, enabling the module to reliably extract relevant information.
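To make the idea of a seed-prompt concrete without depending on DSPy, the plain-Python sketch below mimics what BiodexInferSignature declares: an instruction (playing the role of the docstring) plus prefixed input and output fields. The SeedPrompt class, its render method, and the prompt layout are hypothetical illustrations; DSPy generates its own prompt format from a Signature, and the instruction text here paraphrases the docstring described above.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class SeedPrompt:
    """Hypothetical plain-Python analogue of a DSPy Signature."""
    instruction: str           # plays the role of the signature's docstring
    inputs: Dict[str, str]     # input field name -> prefix
    output_prefix: str         # prefix the LM is asked to complete
    output_desc: str           # description of the expected output

    def render(self, **values: str) -> str:
        lines = [self.instruction, f"Output format: {self.output_desc}.", ""]
        for name, prefix in self.inputs.items():
            lines.append(f"{prefix} {values[name]}")
        lines.append(self.output_prefix)  # left open for the LM to fill in
        return "\n".join(lines)


BIODEX_INFER = SeedPrompt(
    instruction=("Given a snippet from a medical article, identify the adverse "
                 "drug reactions affecting the patient."),
    inputs={"text": "Article:"},
    output_prefix="Reactions:",
    output_desc="list of comma-separated adverse drug reactions",
)

print(BIODEX_INFER.render(text="The patient developed a rash after amoxicillin."))
```

The rendered prompt ends with the open "Reactions:" prefix, which is where the LM's comma-separated answer would go.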
The BioDEX Rank Module seed-prompt is constructed to address the following key challenges:
Selection of Relevant Adverse Drug Reactions
The primary goal is to accurately select the most relevant adverse drug reactions from a given list of options within a medical article snippet. By clearly defining the input and output fields, the module is guided to focus on picking the top 10 applicable reactions from the provided options, ensuring relevance and precision.
Structured Data Processing
The use of dspy.InputField and dspy.OutputField ensures that the data is processed in a structured manner. This structuring helps maintain consistency and accuracy in identifying and formatting the adverse drug reactions, facilitating a reliable extraction and ranking process.
Efficiency and Accuracy
By providing a clear task description and structured input/output fields, the seed-prompt ensures that the Rank module operates efficiently. The module can swiftly process the input text and accurately identify the most relevant adverse drug reactions from the options without requiring extensive manual intervention.
Adaptability
The modular approach facilitated by the DSPy framework allows for easy adaptation of the Rank module to different datasets or slightly varied tasks. By changing the input and output field definitions, the module can be tailored to new requirements, demonstrating flexibility and scalability.
Class Definition
Defines a new class BiodexRankSignature that inherits from dspy.Signature. This class specifies the signature for the Rank module on the BioDEX dataset.
Docstring
Provides a task description. This docstring explains that the task is to pick the 10 most applicable adverse reactions from a given list of options.
Input Fields
Defines an input field with the prefix "Article:". This indicates that the input to the module will be a snippet from a medical article.
Defines another input field with the prefix "Options:". The description specifies that this field will contain a list of comma-separated options to choose from.
Output Field
Defines an output field with the prefix "Reactions:". The description specifies that the output should be a list of comma-separated adverse drug reactions.
The seed-prompt for the BioDEX Rank Module is designed to assist in selecting the most relevant adverse drug reactions from a medical article snippet.
Using the DSPy framework, this seed-prompt defines the structure for both input and output fields, ensuring that the Rank module can accurately identify and rank the most applicable adverse reactions. This structured approach allows for efficient processing and ranking of labels within the context of the BioDEX dataset.
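On the Rank side, the retrieved labels must be passed in through the "Options:" field as a comma-separated list, and the model's "Reactions:" answer must be read back as at most 10 labels. The sketch below shows one way to handle that input and output. The helper names are hypothetical, and discarding answers that are not among the options is an assumption; the source does not say whether the authors filter the Rank output this way.

```python
from typing import List


def format_options(candidates: List[str]) -> str:
    """Join retrieved labels into the comma-separated string for 'Options:'."""
    return ", ".join(candidates)


def parse_ranked(reply: str, candidates: List[str], k: int = 10) -> List[str]:
    """Read the LM's comma-separated answer, keeping at most k labels that are
    actual options (case-insensitive), in the order the LM ranked them."""
    lookup = {c.lower(): c for c in candidates}
    ranked: List[str] = []
    for part in reply.split(","):
        label = lookup.get(part.strip().lower())
        if label is not None and label not in ranked:
            ranked.append(label)
        if len(ranked) == k:
            break
    return ranked


options = ["Rash", "Pruritus", "Nausea", "Headache"]
print(format_options(options))  # Rash, Pruritus, Nausea, Headache
print(parse_ranked("pruritus, Rash, fever, rash", options))  # ['Pruritus', 'Rash']
```

Mapping answers back onto the option strings keeps the final output inside the label ontology even if the LM changes capitalization or adds an unlisted reaction.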
The BioDEX Infer and Rank Modules collaborate to efficiently and accurately identify and rank adverse drug reactions from medical article snippets.
Here’s how they interact and the workflow they follow:
BioDEX Infer Module
The BioDEX Infer Module is the first step in the process. Its primary function is to identify all adverse drug reactions mentioned in a given medical article snippet.
BioDEX Rank Module
Once the Infer Module's predictions have been mapped onto MedDRA labels by the retriever, the BioDEX Rank Module ranks those candidate reactions by relevance to the article.
Workflow Interaction
Initial Extraction:
The BioDEX Infer Module processes the medical article snippet to extract the mentioned adverse drug reactions. This initial step aims to surface all potentially relevant reactions.
Retrieval and Ranking:
The extracted reactions are used as queries to the frozen retriever, which returns matching reactions from the MedDRA label space. These retrieved labels are passed as options to the BioDEX Rank Module, which evaluates them against the same medical article snippet to select and rank the top 10 most applicable reactions.
Final Output:
The ranked reactions are then output by the Rank Module, providing a concise and prioritized list of the most relevant adverse drug reactions for the given medical article snippet.
The BioDEX Infer and Rank Modules work together to efficiently and accurately identify and rank adverse drug reactions from medical article snippets.
The Infer Module extracts the mentioned reactions, the retriever maps them onto MedDRA labels, and the Rank Module selects and prioritizes the top 10 most relevant ones.
This collaborative workflow ensures a structured, consistent, and accurate identification and ranking process, leveraging the DSPy framework's modular and declarative approach for optimal performance within the BioDEX dataset.
The ESCO Infer Module seed-prompt is constructed to address the following key challenges:
Identification of Relevant Skills:
The primary goal is to accurately identify all job skills mentioned in a job vacancy snippet. By clearly defining the input and output fields, the module is guided to focus on extracting skills from the provided text.
Structured Data Processing:
The use of dspy.InputField and dspy.OutputField ensures that the data is processed in a structured manner. This structuring helps maintain consistency and accuracy in identifying and formatting the job skills.
Efficiency and Accuracy:
By providing a clear task description and structured input/output fields, the seed-prompt ensures that the Infer module operates efficiently. The module can quickly process the input text and accurately identify the relevant skills without additional manual intervention.
Adaptability:
The modular approach facilitated by the DSPy framework allows for easy adaptation of the Infer module to different datasets or slightly varied tasks. By changing the input and output field definitions, the module can be tailored to new requirements.
This is accomplished using the DSPy framework, which provides a structured and clear approach to defining the behavior of the Infer module.
Here's how the code block works and what it aims to achieve:
Class Definition
Defines a new class EscoInferSignature that inherits from dspy.Signature. This class specifies the signature for the Infer module on the ESCO job vacancy dataset.
Docstring
Provides a task description. This docstring explains that the task is to identify and return all the ESCO job skills mentioned in a job vacancy snippet.
Input Field
Defines an input field with the prefix "Vacancy:". This indicates that the input to the module will be a snippet from a job vacancy.
Output Field
Defines an output field with the prefix "Skills:". The description specifies that the output should be a list of comma-separated ESCO skills.
The ESCO Rank Module seed-prompt is constructed to address the following key challenges:
Selection of Relevant Job Skills
The primary goal is to accurately select the most relevant job skills from a given list of options within a job vacancy snippet.
By clearly defining the input and output fields, the module is guided to focus on picking the top 10 applicable skills from the provided options, ensuring relevance and precision.
Structured Data Processing
The use of dspy.InputField and dspy.OutputField ensures that the data is processed in a structured manner. This structuring helps maintain consistency and accuracy in identifying and formatting the job skills, facilitating a reliable extraction and ranking process.
Efficiency and Accuracy
By providing a clear task description and structured input/output fields, the seed-prompt ensures that the Rank module operates efficiently. The module can swiftly process the input text and accurately identify the most relevant job skills from the options without requiring extensive manual intervention.
Adaptability
The modular approach facilitated by the DSPy framework allows for easy adaptation of the Rank module to different datasets or slightly varied tasks. By changing the input and output field definitions, the module can be tailored to new requirements, demonstrating flexibility and scalability.
Class Definition
Defines a new class EscoRankSignature that inherits from dspy.Signature. This class specifies the signature for the Rank module on the ESCO job vacancy dataset.
Docstring
Provides a task description. This docstring explains that the task is to pick the 10 most applicable ESCO job skills from a given list of options.
Input Fields
Defines an input field with the prefix "Vacancy:". This indicates that the input to the module will be a snippet from a job vacancy.
Defines another input field with the prefix "Options:". The description specifies that this field will contain a list of comma-separated options to choose from.
Output Field
Defines an output field with the prefix "Skills:". The description specifies that the output should be a list of comma-separated ESCO skills.
The seed-prompt for the ESCO Rank Module is designed to assist in selecting the most relevant job skills from a job vacancy snippet.
Using the DSPy framework, this seed-prompt defines the structure for both input and output fields, ensuring that the Rank module can accurately identify and rank the most applicable job skills. This structured approach allows for efficient processing and ranking of skills within the context of the ESCO dataset.
The ESCO Infer and Rank Modules are designed to work in tandem to efficiently and accurately identify and rank job skills from job vacancy snippets.
Here's how they interact and the workflow they follow:
ESCO Infer Module
The ESCO Infer Module is the first step in the process. Its primary function is to identify all the job skills mentioned in a given job vacancy snippet.
ESCO Rank Module
Once the Infer Module's predictions have been mapped onto ESCO labels by the retriever, the ESCO Rank Module ranks those candidate skills by relevance to the vacancy.
Workflow Interaction
Initial Extraction:
The ESCO Infer Module processes the job vacancy snippet to extract the mentioned job skills. This initial step aims to surface all potentially relevant skills.
Retrieval and Ranking:
The extracted skills are used as queries to the frozen retriever, which returns matching skills from the ESCO label space. These retrieved labels are passed as options to the ESCO Rank Module, which evaluates them against the same job vacancy snippet to select and rank the top 10 most applicable skills.
Final Output:
The ranked skills are then output by the Rank Module, providing a concise and prioritized list of the most relevant job skills for the given job vacancy.
The ESCO Infer and Rank Modules work together to efficiently and accurately identify and rank job skills from job vacancy snippets. The Infer Module extracts the mentioned skills, the retriever maps them onto ESCO labels, and the Rank Module selects and prioritizes the top 10 most relevant ones.
This collaborative workflow ensures a structured, consistent, and accurate identification and ranking process, leveraging the DSPy framework's modular and declarative approach for optimal performance within the ESCO dataset.
Rank-Precision (RP) is a metric used to evaluate the quality of a ranked list of labels produced by the model. It measures how accurately the top-ranked labels match the true (gold) labels.
Rank-Precision at K (RP@K)
Rank-Precision at K (RP@K) is a specific version of rank-precision that evaluates the precision of the ranking up to the top K positions. Here's a breakdown of the metric:
K: This is the rank position up to which we measure the precision. For example, RP@5 would evaluate the precision of the top 5 ranked labels.
R_n: The total number of gold (true) labels for the n-th input. This varies for each input.
Rel(n,k): A relevance function that returns 1 if the k-th output label for input n is relevant (i.e., it matches one of the gold labels), and 0 otherwise.
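Putting these definitions together gives the standard rank-precision formula, reconstructed here from the terms above. N is our notation for the number of evaluated inputs: for each input, the number of relevant labels among the top K is divided by min(K, R_n), and the result is averaged over all inputs.

```latex
\mathrm{RP@}K = \frac{1}{N} \sum_{n=1}^{N} \frac{1}{\min(K, R_n)} \sum_{k=1}^{K} \mathrm{Rel}(n, k)
```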
Relevance and Precision: RP@K directly measures the relevance of the top K predictions, which is crucial for multi-label classification tasks where the goal is to accurately rank the most relevant labels at the top.
Adaptability to Varying Number of Labels: Because each input is normalized by min(K, R_n), RP@K equals precision@K for an input with at least K gold labels and recall@K for an input with fewer than K. This lets it adapt to the varying number of gold labels across different inputs.
Comprehensive Evaluation: RP@K allows for a detailed and nuanced assessment of the model's performance, taking into account the ranked nature of the outputs and the varying importance of correctly ranking relevant labels.
In summary, the authors use RP@K to comprehensively evaluate the effectiveness of their Infer-Retrieve-Rank program in producing relevant and accurately ranked labels across different multi-label classification tasks.
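A direct implementation of RP@K follows, written as a minimal sketch. The function name is ours, and it returns a fraction between 0 and 1; the scores in the results below appear to be reported on a 0 to 100 scale.

```python
from typing import Sequence, Set


def rank_precision_at_k(
    ranked: Sequence[Sequence[str]],  # ranked predictions, one list per input
    gold: Sequence[Set[str]],         # gold label set per input (R_n = its size)
    k: int,
) -> float:
    """RP@K: average over inputs of (relevant labels in the top k) / min(k, R_n)."""
    if not ranked or len(ranked) != len(gold):
        raise ValueError("expected one gold set per ranked list")
    total = 0.0
    for preds, truth in zip(ranked, gold):
        if not truth:
            raise ValueError("every input needs at least one gold label")
        hits = sum(1 for label in preds[:k] if label in truth)  # sum of Rel(n, k)
        total += hits / min(k, len(truth))
    return total / len(ranked)


ranked = [["a", "b", "c"], ["x", "y", "z"]]
gold = [{"a", "c", "d", "e"}, {"y"}]
print(rank_precision_at_k(ranked, gold, k=2))  # 0.75
```

The second input has only one gold label, so its denominator is 1 and finding that label anywhere in the top 2 earns full credit. This is the recall-like behaviour RP@K shows when R_n < K.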
The evaluation of the method and baselines was conducted on four extreme classification datasets, one in the biomedical field and three in the human resources field.
BioDEX (Biomedical Drug Event Extraction):
Source: The dataset is composed of biomedical papers that describe various adverse drug events and include expert-created labels for the specific types of medical reactions discussed.
Ontology Used: The labels are encoded using the MedDRA ontology (Medical Dictionary for Regulatory Activities), which is a standardized set containing approximately 24,300 medical reactions.
Data Characteristics:
Input Length: Inputs can be very long; about half of them exceed roughly 20,000 characters.
Domain Knowledge: Requires biomedical domain knowledge to accurately infer the correct reactions, as only adverse reactions need to be reported, not all medical reactions.
Real-world Relevance: BioDEX models a crucial step in real-world drug safety pipelines.
Dataset Splits:
Training examples: 10
Validation examples: 50
Test examples: 250
Label Distribution:
Median number of labels per input: 3
95th percentile number of labels per input: 14
ESCO (European Skills, Competences, Qualifications and Occupations):
Source: The dataset comes from the ESCO ontology, which is managed by the European Commission Directorate-General for Employment, Social Affairs, and Inclusion. The ontology contains approximately 13,900 distinct concepts used to encode skills, competences, qualifications, and occupations.
Data Characteristics:
The datasets consist of snippets (typically one sentence) of online job vacancies in English with their relevant ESCO labels.
Sub-datasets:
HOUSE: Contains 262 test examples.
TECH: Contains 338 test examples.
TECHWOLF: Contains 326 test examples.
Dataset Splits:
HOUSE and TECH:
Training examples: 10 each, taken from the respective validation sets
Validation examples: Remaining 51 (HOUSE) and 65 (TECH)
TECHWOLF: Has no training or validation split of its own and reuses the HOUSE training and validation splits.
Label Distribution:
Median number of labels per input across these datasets: 1
95th percentile number of labels per input: 4
The evaluation uses datasets that cover both the biomedical domain and the human resources domain.
The BioDEX dataset focuses on identifying adverse drug reactions from lengthy biomedical articles, while the ESCO dataset focuses on identifying job skills from short job vacancy snippets.
The structured data processing and use of well-defined ontologies ensure consistency, accuracy, and real-world relevance, making these datasets suitable for evaluating extreme multi-label classification methods.
The results show the test performance of various models and tasks on the BioDEX and ESCO datasets, measured using the rank-precision (RP) metric at 5 and 10 (RP@5 and RP@10).
The analysis includes baseline methods, the proposed Infer–Retrieve–Rank method, and several fine-tuned systems from the literature.
Baseline Methods
Prior:
Description: This baseline uses the prior probability distribution of labels based on all training data to rank labels.
Performance: Generally low across all datasets, indicating the need for more sophisticated methods to achieve better accuracy.
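A minimal sketch of this baseline, assuming that "prior" means ranking every label by how often it occurs in the training labels (the function name is hypothetical):

```python
from collections import Counter
from typing import Iterable, List


def prior_ranking(training_labels: Iterable[Iterable[str]], k: int) -> List[str]:
    """Rank labels by frequency in the training data. The input text is ignored,
    so every test input receives the same top-k list."""
    counts = Counter(label for labels in training_labels for label in labels)
    return [label for label, _ in counts.most_common(k)]


train = [["nausea", "rash"], ["nausea"], ["headache", "rash"], ["nausea"]]
print(prior_ranking(train, k=2))  # ['nausea', 'rash']
```

Because the prediction never looks at the input, this baseline can only score well when a few labels dominate the dataset, which helps explain its low scores.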
Exact-Match:
Description: This method matches label names exactly within the input document.
Performance: Shows moderate improvement over the prior baseline, particularly effective in ESCO tasks, with RP@5 and RP@10 scores around 4-6.
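A minimal sketch of exact matching follows. Case-insensitive whole-word matching and ordering by first occurrence are assumptions; the source does not say how the authors normalize or order matches. The function name is hypothetical.

```python
import re
from typing import List


def exact_match(text: str, label_names: List[str]) -> List[str]:
    """Return labels whose full name occurs verbatim (case-insensitive, on word
    boundaries) in the input, ordered by where they first appear."""
    lowered = text.lower()
    hits = []
    for name in label_names:
        match = re.search(r"\b" + re.escape(name.lower()) + r"\b", lowered)
        if match:
            hits.append((match.start(), name))
    return [name for _, name in sorted(hits)]


labels = ["project management", "Python", "teamwork", "Java"]
text = "Experience with Python and project management; JavaScript is a plus."
print(exact_match(text, labels))  # ['Python', 'project management']
```

The word-boundary check keeps "Java" from matching inside "JavaScript". Labels that are only paraphrased in the text are still missed, which is the main limitation of this baseline.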
Naive-Retrieve:
Description: Employs pre-trained retrievers (BioLORD for BioDEX and all-mpnet-base-v2 for ESCO tasks) to embed input documents and retrieve relevant labels.
Performance:
ESCO Tasks: Significantly outperforms the prior and exact-match baselines, with RP@5 and RP@10 scores ranging from 26 to 50.
BioDEX Task: Shows weaker performance due to the complexity and length of biomedical documents, with RP@5 and RP@10 scores around 11.
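The sketch below shows the Naive-Retrieve idea: embed the whole input once and return its nearest labels. A toy bag-of-words vector with cosine similarity stands in for the BioLORD and all-mpnet-base-v2 encoders, which is a deliberate simplification. All names are hypothetical.

```python
import math
from collections import Counter
from typing import List


def embed(text: str) -> Counter:
    """Toy bag-of-words 'embedding'; a real system would use a sentence encoder."""
    return Counter(text.lower().replace(",", " ").split())


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def naive_retrieve(document: str, labels: List[str], k: int) -> List[str]:
    """Embed the whole input as one query and return the k most similar labels."""
    doc_vec = embed(document)
    scored = sorted(labels, key=lambda lab: cosine(doc_vec, embed(lab)), reverse=True)
    return scored[:k]


labels = ["software development", "data analysis", "project management"]
print(naive_retrieve("We need someone for data analysis and reporting.", labels, k=1))
# ['data analysis']
```

Because a long document is compressed into a single query vector, it is plausible that this baseline struggles on lengthy, multi-topic BioDEX articles. That is consistent with its weaker BioDEX scores, and it contrasts with Infer–Retrieve–Rank, which issues one short query per inferred label.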
Infer–Retrieve–Rank
Configuration:
Infer Module: Uses Llama-2-7b-chat as the student LM and GPT-3.5-turbo as the teacher LM.
Rank Module: Uses GPT-4 as both the student and teacher.
Retrieval Module: Employs BioLORD for BioDEX and all-mpnet-base-v2 for ESCO tasks.
Performance:
ESCO Datasets (HOUSE, TECH, TECHWOLF):
Achieves state-of-the-art performance, with RP@5 and RP@10 scores between 56 and 71.
Significantly outperforms both baselines and fine-tuned systems, demonstrating the effectiveness of the combined modules.
BioDEX Dataset:
Competitive performance with RP@5 and RP@10 scores of 24.73 and 27.67, respectively.
While not surpassing the best fine-tuned systems, the method shows substantial gains over baselines and indicates potential for further improvement with optimization.
The results demonstrate the efficacy of the Infer–Retrieve–Rank approach in addressing extreme multi-label classification tasks, particularly for the ESCO datasets.
The method provides state-of-the-art performance with significantly less data and no finetuning, making it a cost-effective and scalable solution.
For BioDEX, while not the best, the approach shows promise and potential for further improvement with optimization. The clear advantages in efficiency, adaptability, and modular design underscore the robustness and versatility of the proposed method.
This study introduced Infer–Retrieve–Rank, a program for extreme multi-label classification tasks.
By combining a frozen retriever with two in-context learning modules, the approach achieves state-of-the-art performance on the ESCO benchmarks (HOUSE, TECH and TECHWOLF) and competitive performance on BioDEX.
This methodology highlights the potential of modular and optimized programs in overcoming the complexities and limitations often associated with prompt and pipeline engineering.
The findings underscore that robust, general-purpose solutions can be achieved without the need for extensive finetuning or large amounts of data. The success of Infer–Retrieve–Rank not only sets a new standard for multi-label classification but also paves the way for future advancements in the field. This approach exemplifies how a well-structured and modular design, facilitated by frameworks like DSPy, can deliver high efficiency, adaptability, and scalability.
The promising results of Infer–Retrieve–Rank suggest a shift towards more resilient and efficient methods in prompt and pipeline engineering. As the landscape of machine learning and natural language processing continues to evolve, such modular programs offer a glimpse into a future where complex tasks can be managed with simplicity and precision.
Large Language Models (LLMs) play a central role in the Infer–Retrieve–Rank process, enabling efficient and accurate multi-label classification.
Here's a detailed explanation of their roles and interactions:
Infer Module:
Student LM (Llama-2-7b-chat): This model processes the input text and generates initial label predictions. It acts as the first layer of understanding, leveraging its trained knowledge to infer potential labels from the provided input.
Teacher LM (GPT-3.5-turbo): This model is used during optimization. It runs the Infer step zero-shot on training examples to generate demonstrations, which are then placed in the student LM's prompt as few-shot examples to improve the accuracy and relevance of its predictions.
Retrieve Module:
Frozen Retriever (BioLORD for BioDEX and all-mpnet-base-v2 for ESCO tasks): This component embeds the label predictions from the Infer module and maps them onto the actual label space by comparing them against embeddings of the candidate labels.
Role: The retriever is essential for narrowing down the wide range of possible labels to a more relevant subset based on the context provided by the Infer module.
Rank Module:
Student LM (GPT-4): This model takes the original input text and the retrieved labels and reranks the labels to prioritise the most relevant ones.
Teacher LM (GPT-4): As in the Infer module, the teacher generates demonstrations during optimization, which are used as few-shot examples for the Rank student.
Inference:
The Infer module uses the student LM (Llama-2-7b-chat) to analyse the input text and predict possible labels.
The predicted labels are then parsed and structured for further processing.
Retrieval:
The Retrieve module uses the frozen retriever to find the most similar labels in the label space for each prediction from the Infer module.
This step grounds the free-text predictions in the label ontology, so every candidate passed on is a valid label.
Reranking:
The Rank module uses the student LM (GPT-4) to re-evaluate and reorder the retrieved labels, ensuring that the most relevant labels are prioritised.
The Rank module considers both the input text and the retrieved labels to refine the final label set.
Optimization and Training
During optimization, the teacher LMs (GPT-3.5-turbo for Infer and GPT-4 for Rank) guide the student LMs by generating demonstrations that are used as few-shot examples in the students' prompts.
The optimization process calls the teacher models on the training examples to bootstrap these demonstrations; no model weights are updated.
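The core idea of bootstrapping demonstrations can be sketched in plain Python. This is a simplified illustration, not DSPy's optimizer: the teacher, the recall metric, the threshold and the prompt layout below are all hypothetical stand-ins. Some DSPy optimizers additionally handle multi-step traces and search over candidate demonstration sets.

```python
from typing import Callable, List, Set, Tuple

Example = Tuple[str, Set[str]]  # (input text, gold labels)
Demo = Tuple[str, str]          # (input text, teacher output)


def bootstrap_demos(
    teacher: Callable[[str], str],
    trainset: List[Example],
    metric: Callable[[str, Set[str]], float],
    threshold: float,
    max_demos: int,
) -> List[Demo]:
    """Run the zero-shot teacher on training inputs and keep outputs that score
    at least `threshold` against the gold labels, for use as few-shot demos."""
    demos: List[Demo] = []
    for text, gold in trainset:
        output = teacher(text)
        if metric(output, gold) >= threshold:
            demos.append((text, output))
        if len(demos) == max_demos:
            break
    return demos


def build_student_prompt(instruction: str, demos: List[Demo], text: str) -> str:
    """Prepend the bootstrapped demonstrations to the student's prompt."""
    shots = [f"Input: {t}\nLabels: {o}" for t, o in demos]
    return "\n\n".join([instruction, *shots, f"Input: {text}\nLabels:"])


def recall(output: str, gold: Set[str]) -> float:
    predicted = {p.strip().lower() for p in output.split(",")}
    return len(predicted & gold) / len(gold)


answers = {"rash after penicillin": "rash", "vomiting and dizziness": "nausea"}
trainset = [("rash after penicillin", {"rash"}),
            ("vomiting and dizziness", {"vomiting", "dizziness"})]
demos = bootstrap_demos(lambda t: answers[t], trainset, recall, threshold=1.0, max_demos=4)
print(demos)  # [('rash after penicillin', 'rash')]
prompt = build_student_prompt("Identify the adverse reactions.", demos, "headache after ibuprofen")
```

Only the teacher output that matched the gold labels survives as a demonstration. The student's prompt changes, but no weights are trained.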
In the Infer–Retrieve–Rank process:
LMs drive the Infer and Rank steps, while a frozen retriever handles the Retrieve step in between.
Student LMs handle the core tasks of prediction and ranking, while teacher LMs provide optimisation and guidance.
The combination of these models allows for a highly modular, adaptable, and efficient approach to extreme multi-label classification, capable of achieving state-of-the-art performance with minimal data and no finetuning.