from abc import ABC, abstractmethod

import pandas as pd
import torch
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer

# from finbert_embedding.embedding import FinbertEmbedding


class TextEmbedder(ABC):
    def __init__(self, model_name, paragraphs_path, device, load_existing_index=False):
        """Initialize an instance of the TextEmbedder class.

        Args:
            model_name (str): The name of the SentenceTransformer model to be used for embeddings.
            paragraphs_path (str): The path to the dataset of paragraphs to be embedded.
            device (str): The target device to run the model on ('cpu' or 'cuda').
            load_existing_index (bool): If True, load an existing Faiss index, if available.

        Returns:
            None
        """
        self.dataset = load_from_disk(paragraphs_path)
        self.model = self._load_model(model_name, device)

        assert len(self.dataset) > 0, "The loaded dataset is empty!"

        if load_existing_index:
            self.dataset.load_faiss_index(
                "embeddings", f"{paragraphs_path}/index.faiss"
            )

    # Generate embeddings for each paragraph
    def generate_paragraphs_embedding(self):
        """Generate embeddings for the paragraphs in the dataset.

        Computes an embedding for each paragraph's content and adds it to the
        dataset as a new column named "embeddings".

        Args:
            None

        Returns:
            None
        """
        self.dataset = self.dataset.map(
            lambda x: {"embeddings": self._generate_embeddings(x["content"])}
        )

    # Save the embeddings index
    def save_embeddings(self, output_path):
        """Build a Faiss index over the embeddings and save it to the given path.

        Args:
            output_path (str): The directory in which to save the Faiss embeddings index.

        Returns:
            None
        """
        self.dataset.add_faiss_index(column="embeddings")
        self.dataset.save_faiss_index("embeddings", f"{output_path}/index.faiss")

    # Similarity search over the indexed paragraphs
    def retrieve_faiss(self, query: str, k_total: int, threshold: float):
        """Retrieve passages using Faiss similarity search.

        Args:
            query (str): The query for which similar passages are to be retrieved.
            k_total (int): The total number of passages to retrieve.
            threshold (float): The minimum similarity score a passage must reach to be kept.

        Returns:
            Tuple[List[Dict[str, Any]], np.ndarray]: A tuple containing:
                - A list of dictionaries, each representing a passage with
                  'content' (str) and 'meta' (dict) fields.
                - A numpy array of similarity scores for the retrieved passages.
        """
        question_embedding = self._generate_embeddings(query)
        scores, samples = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k_total
        )

        passages_df = pd.DataFrame(samples)
        passages_df["scores"] = scores / 100
        passages_df = passages_df[passages_df["scores"] > threshold]
        passages_df = passages_df.sort_values(by=["scores"], ascending=False)

        if len(passages_df) == 0:
            return [], []

        contents = passages_df["content"].tolist()
        meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = []
        for i in range(len(contents)):
            passages.append({"content": contents[i], "meta": meta[i]})

        return passages, passages_df["scores"].values

    def retrieve_elastic(self, query: str, k_total: int, threshold: float):
        raise NotImplementedError

    @abstractmethod
    def _load_model(self, model_name: str, device: str):
        pass

    @abstractmethod
    def _generate_embeddings(self, text: str):
        pass


class SentenceTransformersTextEmbedder(TextEmbedder):
    def _load_model(self, model_name: str, device: str):
        """Load a SentenceTransformer model onto the specified device.

        Args:
            model_name (str): The name of the SentenceTransformer model to be loaded.
            device (str): The target device to move the model to ('cpu' or 'cuda').

        Returns:
            SentenceTransformer: The loaded SentenceTransformer model placed on the specified device.
        """
        model = SentenceTransformer(model_name)
        torch_device = torch.device(device)
        model.to(torch_device)
        return model

    def _generate_embeddings(self, text: str):
        """Generate embeddings for a given text using the loaded model.

        Args:
            text (str): The input text for which embeddings are to be generated.

        Returns:
            np.ndarray: An array representing the embeddings of the input text.
        """
        return self.model.encode(text)


# class FinBertTextEmbedder(TextEmbedder):
#     def _load_model(self, model_name: str, device: str):
#         model = FinbertEmbedding(device=device)
#         return model

#     def _generate_embeddings(self, text: str):
#         output = self.model.sentence_vector(text)
#         return output.cpu().numpy()
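# Example usage: a minimal sketch of the intended workflow (embed, index, save,
# then retrieve). The model name, dataset path, query, and threshold below are
# illustrative placeholders, not values defined elsewhere in this module.
if __name__ == "__main__":
    embedder = SentenceTransformersTextEmbedder(
        model_name="sentence-transformers/all-MiniLM-L6-v2",  # any SentenceTransformer checkpoint
        paragraphs_path="data/paragraphs",  # hypothetical path to a saved datasets.Dataset
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Embed every paragraph, build the Faiss index, and persist it next to the dataset.
    embedder.generate_paragraphs_embedding()
    embedder.save_embeddings("data/paragraphs")

    # Retrieve the top passages whose scaled score exceeds the threshold.
    passages, scores = embedder.retrieve_faiss(
        "What were the main findings?", k_total=5, threshold=0.5
    )
    for passage, score in zip(passages, scores):
        print(f"{score:.3f}  {passage['content'][:80]}")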