# ESMA-GPT / text_embedder.py
# vnguyen-nexialog's picture
# initial push
# 4b549a4
from abc import ABC, abstractmethod
import pandas as pd
import torch
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer
# from finbert_embedding.embedding import FinbertEmbedding
class TextEmbedder(ABC):
    """Abstract base for paragraph embedders backed by an on-disk dataset.

    The dataset is loaded with ``datasets.load_from_disk`` and is expected to
    expose a ``"content"`` column (read by ``generate_paragraphs_embedding``
    and ``retrieve_faiss``). Embeddings are stored in an ``"embeddings"``
    column, which is also the name of the FAISS index.

    Subclasses provide the model through ``_load_model`` and the vectorisation
    of a single text through ``_generate_embeddings``.
    """

    def __init__(self, model_name, paragraphs_path, device, load_existing_index=False):
        """Load the paragraph dataset, then the embedding model.

        Args:
            model_name: Model identifier forwarded to ``_load_model``.
            paragraphs_path: Directory handed to ``load_from_disk``; also the
                directory searched for ``index.faiss`` when
                ``load_existing_index`` is set.
            device: Device specifier forwarded to ``_load_model``.
            load_existing_index: When truthy, attach the FAISS index saved at
                ``<paragraphs_path>/index.faiss`` under the name
                ``"embeddings"``.

        Raises:
            AssertionError: If the loaded dataset is empty.
        """
        self.dataset = load_from_disk(paragraphs_path)
        # Raised explicitly rather than through an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type and
        # message are unchanged. It runs before the model is loaded so an
        # empty dataset fails without paying for the model load.
        if len(self.dataset) == 0:
            raise AssertionError("The loaded dataset is empty !!")
        self.model = self._load_model(model_name, device)
        if load_existing_index:
            self.dataset.load_faiss_index(
                "embeddings", f"{paragraphs_path}/index.faiss"
            )

    def generate_paragraphs_embedding(self):
        """Add an ``"embeddings"`` column computed from each row's ``"content"``.

        ``self.dataset`` is replaced by the mapped dataset.
        """
        self.dataset = self.dataset.map(
            lambda row: {"embeddings": self._generate_embeddings(row["content"])}
        )

    def save_embeddings(self, output_path):
        """Build a FAISS index on ``"embeddings"`` and save it to disk.

        Args:
            output_path: Directory in which ``index.faiss`` is written.
        """
        self.dataset.add_faiss_index(column="embeddings")
        self.dataset.save_faiss_index("embeddings", f"{output_path}/index.faiss")

    def retrieve_faiss(self, query: str, k_total: int, threshold: float):
        """Return the nearest paragraphs to ``query`` whose score beats ``threshold``.

        Args:
            query: Text embedded with ``_generate_embeddings`` and used as the
                search vector.
            k_total: Number of neighbours requested from the index.
            threshold: Rows are kept only when their rescaled score
                (raw score divided by 100) is strictly greater than this.

        Returns:
            A ``(passages, scores)`` pair. ``passages`` is a list of
            ``{"content": ..., "meta": ...}`` dicts ordered by descending
            score, where ``meta`` holds every other column of the row
            (including ``"scores"``). ``scores`` is the matching array of
            rescaled scores. When no row passes the threshold the result is
            ``([], [])``.
        """
        question_embedding = self._generate_embeddings(query)
        scores, samples = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k_total
        )

        passages_df = pd.DataFrame(samples)
        # NOTE(review): whether a larger score means a closer match depends on
        # the metric of the FAISS index; confirm against how it was built.
        passages_df["scores"] = scores / 100
        passages_df = passages_df[passages_df["scores"] > threshold]
        if passages_df.empty:
            return [], []

        passages_df = passages_df.sort_values(by=["scores"], ascending=False)
        contents = passages_df["content"].tolist()
        meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = [
            {"content": content, "meta": row_meta}
            for content, row_meta in zip(contents, meta)
        ]
        return passages, passages_df["scores"].values

    def retrieve_elastic(self, query: str, k_total: int, threshold: float):
        """Placeholder for an Elasticsearch-backed retrieval; not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @abstractmethod
    def _load_model(self, model_name: str, device: str):
        """Return the embedding model for ``model_name`` placed on ``device``."""

    @abstractmethod
    def _generate_embeddings(self, text: str):
        """Return the embedding of ``text`` computed with ``self.model``."""
class SentenceTransformersTextEmbedder(TextEmbedder):
    """``TextEmbedder`` whose model is a ``SentenceTransformer``."""

    def _load_model(self, model_name: str, device: str):
        """Instantiate the ``SentenceTransformer`` model on ``device``.

        Args:
            model_name: Name or path accepted by ``SentenceTransformer``.
            device: Device specifier (e.g. ``"cpu"`` or ``"cuda"``).

        Returns:
            The loaded ``SentenceTransformer`` instance.
        """
        # The device is given to the constructor instead of being applied
        # afterwards with ``.to()``: otherwise the model is first placed on
        # the library's default device and only then moved.
        # NOTE(review): older sentence-transformers releases appear to move
        # the model back to its construction-time device inside ``encode()``,
        # which would silently undo a later ``.to()`` - confirm for the
        # pinned version.
        return SentenceTransformer(model_name, device=device)

    def _generate_embeddings(self, text: str):
        """Return ``self.model.encode(text)`` for the given text."""
        return self.model.encode(text)
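# Disabled FinBERT-based embedder. Re-enabling it also requires restoring the
# commented-out `finbert_embedding` import at the top of this file.
# class FinBertTextEmbedder(TextEmbedder):
#     def _load_model(self, model_name: str, device: str):
#         model = FinbertEmbedding(device=device)
#         return model
#
#     def _generate_embeddings(self, text: str):
#         output = self.model.sentence_vector(text)
#         return output.cpu().numpy()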