CSRD_GPT / text_embedder.py
AxelFritz1
first commit
7013379
import os
from abc import ABC, abstractmethod

import pandas as pd
import torch
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer

# from finbert_embedding.embedding import FinbertEmbedding
class TextEmbedder(ABC):
    """Base class that embeds a paragraph dataset and queries it through a Faiss index.

    Subclasses provide the embedding backend by implementing ``_load_model`` and
    ``_generate_embeddings``.
    """

    # Dataset column that holds the paragraph vectors; also used as the Faiss index name.
    _EMBEDDINGS_COLUMN = "embeddings"
    # File name of the serialized Faiss index inside a dataset/output directory.
    _INDEX_FILENAME = "index.faiss"

    def __init__(self, model_name, paragraphs_path, device, load_existing_index=False):
        """Initialize an instance of the TextEmbedder class.

        Args:
            model_name (str): The name of the model to be used for embeddings
                (interpreted by the subclass's ``_load_model``).
            paragraphs_path (str): Path to a dataset directory readable by
                ``datasets.load_from_disk``.
            device (str): The target device to run the model ('cpu' or 'cuda').
            load_existing_index (bool): If True, load the Faiss index stored at
                ``<paragraphs_path>/index.faiss``. No existence check is done here,
                so the file must already be present.

        Raises:
            ValueError: If the loaded dataset is empty.
        """
        self.dataset = load_from_disk(paragraphs_path)
        # Validate before loading the model: model loading is the expensive step, and an
        # explicit raise (unlike ``assert``) is not stripped when Python runs with -O.
        if len(self.dataset) == 0:
            raise ValueError("The loaded dataset is empty !!")
        self.model = self._load_model(model_name, device)
        if load_existing_index:
            self.dataset.load_faiss_index(
                self._EMBEDDINGS_COLUMN, self._index_path(paragraphs_path)
            )

    @classmethod
    def _index_path(cls, directory):
        """Return the location of the serialized Faiss index inside ``directory``."""
        return os.path.join(directory, cls._INDEX_FILENAME)

    def generate_paragraphs_embedding(self):
        """Embed every paragraph of the dataset.

        Applies ``_generate_embeddings`` to each row's "content" field and replaces
        ``self.dataset`` with a copy that has an extra "embeddings" column.

        Returns:
            None
        """
        self.dataset = self.dataset.map(
            lambda row: {self._EMBEDDINGS_COLUMN: self._generate_embeddings(row["content"])}
        )

    def save_embeddings(self, output_path):
        """Build a Faiss index over the "embeddings" column and write it to disk.

        The index also stays attached to ``self.dataset``, so ``retrieve_faiss`` can be
        used right afterwards.

        Args:
            output_path (str): Directory in which ``index.faiss`` is written; it is
                created if it does not exist yet.

        Returns:
            None
        """
        self.dataset.add_faiss_index(
            column=self._EMBEDDINGS_COLUMN, index_name=self._EMBEDDINGS_COLUMN
        )
        os.makedirs(output_path, exist_ok=True)
        self.dataset.save_faiss_index(
            self._EMBEDDINGS_COLUMN, self._index_path(output_path)
        )

    def retrieve_faiss(self, query: str, k_total: int, threshold: float):
        """Retrieve the passages nearest to ``query`` through the Faiss index.

        The "embeddings" index must already be attached to the dataset, either via
        ``load_existing_index=True`` in the constructor or by calling ``save_embeddings``.

        Args:
            query (str): The query for which similar passages are to be retrieved.
            k_total (int): Number of nearest neighbours requested from the index,
                before threshold filtering.
            threshold (float): A passage is kept only if its rescaled score
                (raw index score / 100) is strictly greater than this value.

        Returns:
            Tuple[List[Dict[str, Any]], np.ndarray]: The kept passages, ordered by
            descending score, each a dict with 'content' (str) and 'meta' (a dict of
            every other retrieved column, including 'scores' and, when present, the
            "embeddings" vectors), plus the array of their scores in the same order.
            When no passage passes the threshold, ``([], [])`` is returned instead.
        """
        query_embedding = self._generate_embeddings(query)
        scores, samples = self.dataset.get_nearest_examples(
            self._EMBEDDINGS_COLUMN, query_embedding, k=k_total
        )
        passages_df = pd.DataFrame(samples)
        # NOTE(review): ``add_faiss_index`` builds an L2 (distance) index unless a metric
        # is specified. In that case a smaller score means a closer passage, and the
        # "greater than threshold" filter plus the descending sort below would favour the
        # farthest hits. Confirm the intended metric. The /100 rescaling is kept as-is;
        # its rationale is not visible here.
        passages_df["scores"] = scores / 100
        passages_df = passages_df[passages_df["scores"] > threshold]
        passages_df = passages_df.sort_values(by=["scores"], ascending=False)
        if passages_df.empty:
            # Plain lists (not an empty array) are kept for existing callers.
            return [], []
        contents = passages_df["content"].tolist()
        meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = [
            {"content": content, "meta": record} for content, record in zip(contents, meta)
        ]
        return passages, passages_df["scores"].values

    def retrieve_elastic(self, query: str, k_total: int, threshold: float):
        """Placeholder for an Elasticsearch-based retrieval; not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @abstractmethod
    def _load_model(self, model_name: str, device: str):
        """Load and return the embedding model used by ``_generate_embeddings``."""

    @abstractmethod
    def _generate_embeddings(self, text: str):
        """Return the embedding vector for ``text``."""
class SentenceTransformersTextEmbedder(TextEmbedder):
    """TextEmbedder backed by a ``sentence_transformers.SentenceTransformer`` model."""

    def _load_model(self, model_name: str, device: str):
        """Load a SentenceTransformer model onto the specified device.

        Args:
            model_name (str): The name of the SentenceTransformer model to be loaded.
            device (str): The target device for the model ('cpu' or 'cuda').

        Returns:
            SentenceTransformer: The loaded model, placed on the specified device.
        """
        # Validate/normalize the device string through torch, then give it to the
        # constructor rather than calling ``.to()`` afterwards. This loads the weights
        # straight onto the target device. It also matters for older
        # sentence-transformers releases, whose ``encode`` moves the model back to the
        # device chosen at construction time.
        torch_device = torch.device(device)
        return SentenceTransformer(model_name, device=str(torch_device))

    def _generate_embeddings(self, text: str):
        """Generate the embedding of ``text`` with the loaded model.

        Args:
            text (str): The input text for which embeddings are to be generated.

        Returns:
            np.ndarray: The embedding returned by ``SentenceTransformer.encode``.
        """
        return self.model.encode(text)
# class FinBertTextEmbedder(TextEmbedder):
# def _load_model(self, model_name: str, device: str):
# model = FinbertEmbedding(device=device)
# return model
# def _generate_embeddings(self, text: str):
# output = self.model.sentence_vector(text)
# return output.cpu().numpy()