from abc import ABC, abstractmethod

import pandas as pd
import torch
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer

# from finbert_embedding.embedding import FinbertEmbedding


class TextEmbedder(ABC):
    """Embed the paragraphs of an on-disk ``datasets`` dataset and search them with FAISS.

    Subclasses provide the concrete model via ``_load_model`` and
    ``_generate_embeddings``. Paragraph text is read from the ``"content"``
    column; embeddings are stored in (and searched through) the
    ``"embeddings"`` column.
    """

    def __init__(
        self,
        model_name: str,
        paragraphs_path: str,
        device: str,
        load_existing_index: bool = False,
    ):
        """Load the paragraph dataset and the embedding model.

        Args:
            model_name: Identifier forwarded to ``_load_model``.
            paragraphs_path: Directory readable by ``datasets.load_from_disk``.
                When ``load_existing_index`` is true, ``index.faiss`` is also
                read from this directory.
            device: Device string forwarded to ``_load_model``.
            load_existing_index: If true, attach ``{paragraphs_path}/index.faiss``
                as the FAISS index of the ``"embeddings"`` column.

        Raises:
            ValueError: If the loaded dataset contains no rows.
        """
        self.dataset = load_from_disk(paragraphs_path)
        # Validate before loading the (potentially expensive) model, and raise
        # rather than `assert` so the check is not stripped under `python -O`.
        if len(self.dataset) == 0:
            raise ValueError("The loaded dataset is empty !!")
        self.model = self._load_model(model_name, device)

        if load_existing_index:
            self.dataset.load_faiss_index(
                "embeddings", f"{paragraphs_path}/index.faiss"
            )

    def generate_paragraphs_embedding(self) -> None:
        """Add an ``"embeddings"`` column computed from each row's ``"content"``.

        Rows are embedded one at a time (non-batched ``map``), because
        ``_generate_embeddings`` is specified for a single ``str``.
        """
        self.dataset = self.dataset.map(
            lambda x: {"embeddings": self._generate_embeddings(x["content"])}
        )

    def save_embeddings(self, output_path: str) -> None:
        """Build a FAISS index over ``"embeddings"`` and write it to ``{output_path}/index.faiss``.

        NOTE(review): no metric/string factory is passed, so the library
        default index is built. With ``datasets`` that is a flat index whose
        default metric is L2, meaning lower scores are closer. Confirm this
        matches the "higher is better" assumption in ``retrieve_faiss``.
        """
        self.dataset.add_faiss_index(column="embeddings")
        self.dataset.save_faiss_index("embeddings", f"{output_path}/index.faiss")

    def retrieve_faiss(self, query: str, k_total: int, threshold: float):
        """Return up to ``k_total`` passages whose scaled score exceeds ``threshold``.

        Args:
            query: Text to embed and search for.
            k_total: Number of nearest neighbours requested from the index.
            threshold: Minimum scaled score (raw FAISS score / 100) a passage
                needs to be kept; the comparison is strict (``>``).

        Returns:
            A pair ``(passages, scores)``. ``passages`` is a list of
            ``{"content": str, "meta": dict}`` ordered by descending scaled
            score, where ``meta`` holds every other returned column (including
            ``"scores"``). ``scores`` is the matching numpy array of scaled
            scores. When nothing passes the threshold, ``([], [])`` is returned.
        """
        question_embedding = self._generate_embeddings(query)
        scores, samples = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k_total
        )

        passages_df = pd.DataFrame(samples)
        # NOTE(review): the /100 scaling is presumably meant to bring scores
        # into a ~[0, 1] range; if the index uses the L2 default (see
        # `save_embeddings`), these are distances and the `>` filter plus the
        # descending sort keep/prefer the *farthest* passages — confirm.
        passages_df["scores"] = scores / 100
        passages_df = passages_df[passages_df["scores"] > threshold]
        if passages_df.empty:
            return [], []
        passages_df = passages_df.sort_values(by=["scores"], ascending=False)

        contents = passages_df["content"].tolist()
        meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = [
            {"content": content, "meta": row_meta}
            for content, row_meta in zip(contents, meta)
        ]
        return passages, passages_df["scores"].values

    def retrieve_elastic(self, query: str, k_total: int, threshold: float):
        """Elasticsearch-based retrieval; not implemented for this class."""
        raise NotImplementedError

    @abstractmethod
    def _load_model(self, model_name: str, device: str):
        """Load and return the embedding model placed on ``device``."""

    @abstractmethod
    def _generate_embeddings(self, text: str):
        """Return the embedding vector for a single ``text``."""


class SentenceTransformersTextEmbedder(TextEmbedder):
    """``TextEmbedder`` backed by a ``sentence_transformers.SentenceTransformer``."""

    def _load_model(self, model_name: str, device: str):
        """Load ``model_name`` with SentenceTransformer and move it to ``device``."""
        model = SentenceTransformer(model_name)
        torch_device = torch.device(device)
        model.to(torch_device)
        return model

    def _generate_embeddings(self, text: str):
        """Encode ``text`` with the loaded SentenceTransformer."""
        return self.model.encode(text)


# class FinBertTextEmbedder(TextEmbedder):
#     def _load_model(self, model_name: str, device: str):
#         model = FinbertEmbedding(device=device)
#         return model
#
#     def _generate_embeddings(self, text: str):
#         output = self.model.sentence_vector(text)
#         return output.cpu().numpy()