|
from abc import ABC, abstractmethod |
|
|
|
import pandas as pd |
|
import torch |
|
from datasets import load_from_disk |
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
|
class TextEmbedder(ABC):
    """Base class that embeds dataset paragraphs and retrieves them via FAISS.

    The dataset is read with ``load_from_disk``; every row is expected to have
    a ``"content"`` column, and embeddings are stored in an ``"embeddings"``
    column. Subclasses supply the model through ``_load_model`` and the
    text-to-vector step through ``_generate_embeddings``.
    """

    def __init__(self, model_name, paragraphs_path, device, load_existing_index=False):
        """Load the dataset and the model, optionally attaching a saved index.

        Args:
            model_name: Forwarded to ``_load_model``.
            paragraphs_path: Directory passed to ``load_from_disk``; when
                ``load_existing_index`` is truthy, the FAISS index is read
                from ``{paragraphs_path}/index.faiss``.
            device: Forwarded to ``_load_model``.
            load_existing_index: If truthy, load the saved FAISS index under
                the name ``"embeddings"``.

        Raises:
            AssertionError: If the loaded dataset has no rows.
        """
        self.dataset = load_from_disk(paragraphs_path)

        # Validate before loading the model so an empty dataset fails fast.
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        if len(self.dataset) == 0:
            raise AssertionError("The loaded dataset is empty !!")

        self.model = self._load_model(model_name, device)

        if load_existing_index:
            self.dataset.load_faiss_index(
                "embeddings", f"{paragraphs_path}/index.faiss"
            )

    def generate_paragraphs_embedding(self):
        """Add an ``"embeddings"`` column computed from each row's ``"content"``.

        ``self.dataset`` is replaced by the mapped dataset.
        """
        self.dataset = self.dataset.map(
            lambda x: {"embeddings": self._generate_embeddings(x["content"])}
        )

    def save_embeddings(self, output_path):
        """Build a FAISS index on ``"embeddings"`` and write it to disk.

        The index is written to ``{output_path}/index.faiss``. Only the index
        file is written here; the dataset itself is not saved by this method.
        """
        self.dataset.add_faiss_index(column="embeddings")
        self.dataset.save_faiss_index("embeddings", f"{output_path}/index.faiss")

    def retrieve_faiss(self, query: str, k_total: int, threshold: float):
        """Return the nearest paragraphs to ``query`` whose score passes ``threshold``.

        Args:
            query: Text to embed with ``_generate_embeddings``.
            k_total: Number of neighbours requested from the index.
            threshold: Rows are kept only when ``score / 100 > threshold``.

        Returns:
            A ``(passages, scores)`` pair. ``passages`` is a list of
            ``{"content": ..., "meta": {...}}`` dicts, where ``meta`` holds
            every other column of the row (including ``"scores"``), ordered by
            descending score. ``scores`` is the matching array of scaled
            scores. When no row passes the threshold, two empty lists are
            returned.
        """
        question_embedding = self._generate_embeddings(query)
        scores, samples = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k_total
        )

        passages_df = pd.DataFrame(samples)
        # Raw index scores are scaled down by 100 before thresholding.
        passages_df["scores"] = scores / 100
        # NOTE(review): higher scores are treated as better here (kept when
        # above the threshold, sorted descending) -- confirm this matches the
        # metric of the FAISS index in use.
        passages_df = passages_df[passages_df["scores"] > threshold]
        passages_df = passages_df.sort_values(by=["scores"], ascending=False)

        if len(passages_df) == 0:
            return [], []

        contents = passages_df["content"].tolist()
        meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = [
            {"content": content, "meta": row_meta}
            for content, row_meta in zip(contents, meta)
        ]
        return passages, passages_df["scores"].values

    def retrieve_elastic(self, query: str, k_total: int, threshold: float):
        """Placeholder for an Elasticsearch-backed retriever; not implemented."""
        raise NotImplementedError

    @abstractmethod
    def _load_model(self, model_name: str, device: str):
        """Return the model object that will be stored in ``self.model``."""
        pass

    @abstractmethod
    def _generate_embeddings(self, text: str):
        """Return the embedding of ``text`` computed with ``self.model``."""
        pass
|
|
|
|
|
class SentenceTransformersTextEmbedder(TextEmbedder):
    """``TextEmbedder`` whose model is a ``SentenceTransformer`` instance."""

    def _load_model(self, model_name: str, device: str):
        """Create a ``SentenceTransformer`` for ``model_name`` and move it to ``device``.

        Args:
            model_name: Passed to the ``SentenceTransformer`` constructor.
            device: Device string converted with ``torch.device``.

        Returns:
            The constructed model, after ``.to(...)`` has been called on it.
        """
        encoder = SentenceTransformer(model_name)
        encoder.to(torch.device(device))
        return encoder

    def _generate_embeddings(self, text: str):
        """Return ``self.model.encode(text)``."""
        embedding = self.model.encode(text)
        return embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|