Embedding Model Construction

This page walks through the process of turning LLaMA-2-7B into a text embedding model using the Hugging Face Transformers library and PyTorch.

We'll follow the LLM2Vec approach described in the paper "LLM2Vec: Large Language Models Are Secretly Powerful Text Encoders".

Step 1: Install the necessary libraries

First, make sure you have the required libraries installed:

pip install torch transformers

Step 2: Load the pre-trained LLaMA-2-7B model

Load the pre-trained LLaMA-2-7B model using the Hugging Face Transformers library:

from transformers import LlamaForCausalLM, LlamaTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # gated repo: requires accepting the LLaMA-2 license on the Hugging Face Hub
tokenizer = LlamaTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name)

Step 3: Enable bidirectional attention

To enable bidirectional attention, the causal attention mask has to be replaced with a full mask so that every token can attend to every other token.

One way to sketch this is to create a custom model class that inherits from LlamaForCausalLM and overrides the forward method. Note that recent versions of Transformers build the causal mask inside the model itself, so a faithful LLM2Vec setup also needs that internal mask construction patched, as the LLM2Vec reference implementation does:

import torch

class LlamaBidirectionalAttention(LlamaForCausalLM):
    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Replace the padding mask with an all-ones mask so no position is hidden.
        # Note: LlamaForCausalLM still builds its causal mask internally, so a fully
        # bidirectional setup also needs that construction patched (see the LLM2Vec
        # reference implementation); this class is only a sketch of the idea.
        if attention_mask is not None:
            attention_mask = torch.ones_like(attention_mask)
        return super().forward(input_ids, attention_mask=attention_mask, **kwargs)
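
You would then load the pre-trained weights into this class instead of the stock LlamaForCausalLM:

model = LlamaBidirectionalAttention.from_pretrained(model_name)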

Step 4: Masked Next Token Prediction (MNTP)

Implement the MNTP training objective to adapt the model to use bidirectional attention.

You can write a custom training loop, or modify an existing masked language modelling script, to mask a fraction of the input tokens. Unlike BERT-style MLM, the loss for a masked token at position i is computed from the logits at the previous position (i - 1), which matches the model's next-token prediction head.

def mntp_loss(model, input_ids, attention_mask):
    # Mask a fraction of the input tokens; labels are -100 everywhere except the
    # masked positions, so unmasked tokens are ignored by the loss
    masked_input_ids, labels = mask_tokens(input_ids)
    
    # Forward pass with the masked input
    outputs = model(masked_input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    
    # Predict each masked token from the logits at the previous position:
    # logits at position i - 1 are aligned with the label at position i
    shifted_logits = logits[..., :-1, :].contiguous()
    shifted_labels = labels[..., 1:].contiguous()
    loss = torch.nn.functional.cross_entropy(
        shifted_logits.view(-1, shifted_logits.size(-1)),
        shifted_labels.view(-1),
        ignore_index=-100,
    )
    
    return loss
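
The function above assumes a mask_tokens helper, which is not part of the Transformers API. Here is a minimal sketch; the 15% masking probability is an assumption, and because LLaMA's tokenizer has no dedicated mask token, the <unk> token (id 0) is used as a placeholder:

def mask_tokens(input_ids, mask_prob=0.15, mask_token_id=0):
    # mask_token_id=0 is LLaMA's <unk> token, used here as a stand-in mask token
    labels = input_ids.clone()
    # Randomly select positions to mask
    mask = torch.rand(input_ids.shape, device=input_ids.device) < mask_prob
    # Unmasked positions get the ignore index so they don't contribute to the loss
    labels[~mask] = -100
    # Replace the selected positions with the placeholder token
    masked_input_ids = input_ids.clone()
    masked_input_ids[mask] = mask_token_id
    return masked_input_ids, labels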

Step 5: Unsupervised Contrastive Learning (SimCSE)

Apply unsupervised contrastive learning using the SimCSE approach.

Pass the input sequence through the model twice with independently sampled dropout masks to obtain two different representations for the same sequence.

Maximise the similarity between these two representations while minimising the similarity with representations of other sequences in the batch.

def simcse_loss(model, input_ids, attention_mask):
    # Dropout must be active so the two passes differ; note that dropout in LLaMA
    # defaults to 0, so it may need to be raised (e.g. attention_dropout in the
    # config) for SimCSE to have any effect
    model.train()
    
    # Two forward passes with independently sampled dropout masks
    outputs1 = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
    outputs2 = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
    
    # Pool the final hidden states into one vector per sequence
    pooled_outputs1 = mean_pooling(outputs1.hidden_states[-1], attention_mask)
    pooled_outputs2 = mean_pooling(outputs2.hidden_states[-1], attention_mask)
    
    # In-batch contrastive (InfoNCE) loss between the two views
    loss = contrastive_loss(pooled_outputs1, pooled_outputs2)
    
    return loss
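
Both mean_pooling and contrastive_loss are helpers that still need to be defined. A minimal sketch of each, where the 0.05 temperature is an assumed value in line with common SimCSE settings:

import torch.nn.functional as F

def mean_pooling(hidden_states, attention_mask):
    # Average the token representations, ignoring padded positions
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def contrastive_loss(z1, z2, temperature=0.05):
    # Cosine similarity between every pair of pooled views in the batch
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    sim = z1 @ z2.T / temperature
    # Each sequence's positive pair is its own second view (the diagonal)
    targets = torch.arange(sim.size(0), device=sim.device)
    return F.cross_entropy(sim, targets)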

Step 6: Training

Train the model on a suitable dataset, such as English Wikipedia. In the LLM2Vec paper the two objectives are applied sequentially (MNTP adaptation first, then unsupervised SimCSE); the simplified loop below combines both losses in one pass for illustration.

You can use a dataset like Wikitext-103 for the MNTP step and a subset of Wikipedia sentences for the unsupervised SimCSE step.
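
The loop below assumes that a dataloader yielding (input_ids, attention_mask) batches and an optimizer already exist. A minimal sketch of that setup, where the toy corpus, batch size, sequence length, and learning rate are all illustrative assumptions:

from torch.utils.data import DataLoader

# Toy corpus standing in for Wikitext-103 / Wikipedia sentences
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "LLM2Vec turns a decoder-only LLM into a text encoder.",
]

# LLaMA has no padding token by default, so reuse the EOS token
tokenizer.pad_token = tokenizer.eos_token
encodings = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")

# Pair each sequence with its attention mask and batch them
dataset = list(zip(encodings["input_ids"], encodings["attention_mask"]))
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)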

# Training loop
for batch in dataloader:
    input_ids, attention_mask = batch
    
    # Move the batch to the same device as the model
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)
    
    # Compute MNTP loss
    mntp_loss_value = mntp_loss(model, input_ids, attention_mask)
    
    # Compute SimCSE loss
    simcse_loss_value = simcse_loss(model, input_ids, attention_mask)
    
    # Combine the losses
    total_loss = mntp_loss_value + simcse_loss_value
    
    # Backward pass and optimization step
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
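
Once training finishes, the adapted weights and tokenizer can be saved with the standard save_pretrained API (the output directory name is just an example):

model.save_pretrained("llama-2-7b-llm2vec")
tokenizer.save_pretrained("llama-2-7b-llm2vec")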

After training, you will have an LLaMA-2-7B model that has been transformed into a text embedding model using the LLM2Vec approach.

You can then use this model to generate embeddings for various downstream tasks.
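
As a minimal sketch, reusing the mean_pooling helper from the SimCSE step and an illustrative example sentence, generating an embedding looks like this:

model.eval()

def embed(texts):
    # Tokenize and run the model without gradient tracking
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    # Mask-aware mean pooling over the final hidden layer gives one vector per text
    return mean_pooling(outputs.hidden_states[-1], inputs["attention_mask"])

embeddings = embed(["LLM2Vec turns LLaMA-2 into a text encoder."])
print(embeddings.shape)  # (1, hidden_size)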

Note: This is a high-level overview of the process, and you may need to adapt the code snippets to fit your specific requirements and environment.
