|
from abc import ABC, abstractmethod |
|
|
|
import pandas as pd |
|
import torch |
|
from datasets import load_from_disk |
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
|
class TextEmbedder(ABC):
    """Embed a dataset of paragraphs and search it through a Faiss index.

    Subclasses provide the model (``_load_model``) and the text encoder
    (``_generate_embeddings``); this base class handles dataset loading,
    index persistence and retrieval.
    """

    def __init__(self, model_name, paragraphs_path, device, load_existing_index=False):
        """Initialize an instance of the TextEmbedder class.

        Args:
            model_name (str): The name of the model to be used for embeddings;
                forwarded to ``_load_model``.
            paragraphs_path (str): The path to the dataset of paragraphs to be embedded.
            device (str): The target device to run the model ('cpu' or 'cuda');
                forwarded to ``_load_model``.
            load_existing_index (bool): If True, load the Faiss index stored at
                ``{paragraphs_path}/index.faiss``.

        Raises:
            ValueError: If the loaded dataset contains no rows.
        """
        self.dataset = load_from_disk(paragraphs_path)

        # Explicit raise instead of ``assert``: asserts are removed under
        # ``python -O``. Checked before loading the model so that an empty
        # dataset fails fast, without paying for the model load.
        if len(self.dataset) == 0:
            raise ValueError("The loaded dataset is empty !!")

        self.model = self._load_model(model_name, device)

        if load_existing_index:
            self.dataset.load_faiss_index(
                "embeddings", f"{paragraphs_path}/index.faiss"
            )

    def generate_paragraphs_embedding(self):
        """Generate embeddings for paragraphs in the dataset.

        Encodes the "content" field of every row with ``_generate_embeddings``
        and stores the result in a new "embeddings" column. ``self.dataset`` is
        replaced by the mapped dataset.

        Returns:
            None
        """
        self.dataset = self.dataset.map(
            lambda row: {"embeddings": self._generate_embeddings(row["content"])}
        )

    def save_embeddings(self, output_path):
        """Build a Faiss index over the "embeddings" column and save it.

        Args:
            output_path (str): Directory in which ``index.faiss`` is written.

        Returns:
            None
        """
        self.dataset.add_faiss_index(column="embeddings")
        self.dataset.save_faiss_index("embeddings", f"{output_path}/index.faiss")

    def retrieve_faiss(self, query: str, k_total: int, threshold: float):
        """Retrieve passages using Faiss similarity search.

        Args:
            query (str): The query for which similar passages are to be retrieved.
            k_total (int): The number of nearest passages requested from the index.
            threshold (float): Passages are kept only if their scaled score
                (raw score / 100) is strictly greater than this value.

        Returns:
            tuple: ``(passages, scores)`` where

                - ``passages`` is a list of dicts with 'content' (str) and
                  'meta' (dict of every other column, including 'scores'),
                  ordered by descending score;
                - ``scores`` is a numpy array of the scaled scores in the same
                  order.

            When no passage passes the threshold, ``([], [])`` is returned
            (two empty lists, not an array).
        """
        question_embedding = self._generate_embeddings(query)
        scores, samples = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k_total
        )

        passages_df = pd.DataFrame(samples)
        # NOTE(review): the division by 100 presumably rescales raw index
        # scores into the range `threshold` is expressed in - confirm.
        # NOTE(review): the index is built without an explicit metric_type in
        # save_embeddings; verify that "higher score = more similar" holds for
        # the metric actually in use.
        passages_df["scores"] = scores / 100
        passages_df = passages_df[passages_df["scores"] > threshold]
        passages_df = passages_df.sort_values(by=["scores"], ascending=False)

        if passages_df.empty:
            return [], []

        contents = passages_df["content"].tolist()
        meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = [
            {"content": content, "meta": record}
            for content, record in zip(contents, meta)
        ]
        return passages, passages_df["scores"].values

    def retrieve_elastic(self, query: str, k_total: int, threshold: float):
        """Placeholder for an Elasticsearch-backed retriever; not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @abstractmethod
    def _load_model(self, model_name: str, device: str):
        """Return the embedding model; its result is stored as ``self.model``."""
        pass

    @abstractmethod
    def _generate_embeddings(self, text: str):
        """Return the embedding of ``text`` as produced by ``self.model``."""
        pass
|
|
|
|
|
class SentenceTransformersTextEmbedder(TextEmbedder):
    """TextEmbedder whose embeddings come from a sentence-transformers model."""

    def _load_model(self, model_name: str, device: str):
        """Load a SentenceTransformer model onto the specified device.

        Args:
            model_name (str): The name of the SentenceTransformer model to be loaded.
            device (str): The target device for the model ('cpu' or 'cuda').

        Returns:
            SentenceTransformer: The loaded model, placed on ``device``.
        """
        # Give the device to the constructor instead of calling ``.to()``
        # afterwards: the model is created directly where it is needed (no
        # intermediate load on the library's default device), and the library
        # itself knows which device to use when encoding.
        return SentenceTransformer(model_name, device=device)

    def _generate_embeddings(self, text: str):
        """Generate embeddings for a given text using the loaded model.

        Args:
            text (str): The input text for which embeddings are to be generated.

        Returns:
            The value returned by ``self.model.encode(text)``.
        """
        return self.model.encode(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|