File size: 5,263 Bytes
7013379 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from abc import ABC, abstractmethod
import pandas as pd
import torch
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer
# from finbert_embedding.embedding import FinbertEmbedding
class TextEmbedder(ABC):
    """Abstract base for embedding a paragraphs dataset and searching it with FAISS.

    Subclasses provide the embedding backend by implementing ``_load_model``
    and ``_generate_embeddings``; this class handles dataset loading, adding an
    "embeddings" column, building/saving/loading the FAISS index, and search.
    """

    def __init__(
        self,
        model_name: str,
        paragraphs_path: str,
        device: str,
        load_existing_index: bool = False,
    ):
        """Initialize an instance of the TextEmbedder class.

        Args:
            model_name (str): Identifier of the model, passed to ``_load_model``.
            paragraphs_path (str): Path of the paragraphs dataset, loaded with
                ``load_from_disk``.
            device (str): Target device for the model (e.g. 'cpu' or 'cuda').
            load_existing_index (bool): If True, attach the FAISS index stored
                at ``{paragraphs_path}/index.faiss`` under the name "embeddings".

        Raises:
            AssertionError: If the loaded dataset is empty.
        """
        self.dataset = load_from_disk(paragraphs_path)
        # Validate the dataset before loading the (potentially large) model so
        # an empty dataset fails fast. Note: `assert` is skipped under `-O`.
        assert len(self.dataset) > 0, "The loaded dataset is empty !!"
        self.model = self._load_model(model_name, device)
        if load_existing_index:
            self.dataset.load_faiss_index(
                "embeddings", f"{paragraphs_path}/index.faiss"
            )

    def generate_paragraphs_embedding(self):
        """Embed every paragraph and store the vectors in an "embeddings" column.

        Each example's "content" field is passed to ``_generate_embeddings``
        through ``Dataset.map``; the resulting dataset replaces ``self.dataset``.
        """
        # NOTE(review): `map` returns a new dataset object, so a FAISS index
        # attached in __init__ (load_existing_index=True) is presumably not
        # carried over to it — confirm before relying on that combination.
        self.dataset = self.dataset.map(
            lambda example: {
                "embeddings": self._generate_embeddings(example["content"])
            }
        )

    def save_embeddings(self, output_path: str):
        """Build a FAISS index over the "embeddings" column and write it to disk.

        The dataset must already have an "embeddings" column (for example one
        produced by ``generate_paragraphs_embedding``). Only the index is
        written; the dataset itself is not saved.

        Args:
            output_path (str): Directory in which ``index.faiss`` is written.
        """
        self.dataset.add_faiss_index(column="embeddings")
        self.dataset.save_faiss_index("embeddings", f"{output_path}/index.faiss")

    def retrieve_faiss(self, query: str, k_total: int, threshold: float):
        """Retrieve passages for ``query`` using the "embeddings" FAISS index.

        Args:
            query (str): Query text, embedded with ``_generate_embeddings``.
            k_total (int): Number of nearest examples requested from the index.
            threshold (float): Passages whose scaled score (raw FAISS score
                divided by 100) is not strictly greater than this are dropped.

        Returns:
            Tuple[List[Dict[str, Any]], np.ndarray]: The kept passages, sorted
            by descending scaled score, each a dict with "content" (the passage
            text) and "meta" (every other column of the example, including its
            "scores" value), plus the matching array of scaled scores.
            Returns ``([], [])`` when no passage passes the threshold.
        """
        question_embedding = self._generate_embeddings(query)
        scores, samples = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k_total
        )
        passages_df = pd.DataFrame(samples)
        # Raw FAISS scores are divided by 100; `threshold` applies to this
        # scaled value.
        # NOTE(review): an index built by `add_faiss_index` without a metric is
        # presumably an L2 index whose scores are distances (lower = closer),
        # while the filter and descending sort below treat higher as more
        # similar — confirm which metric the saved index actually uses.
        passages_df["scores"] = scores / 100
        passages_df = passages_df[passages_df["scores"] > threshold]
        if passages_df.empty:
            return [], []
        passages_df = passages_df.sort_values(by=["scores"], ascending=False)

        contents = passages_df["content"].tolist()
        # Every column except "content" becomes metadata — this includes
        # "scores" and, if the dataset has it, the "embeddings" vector.
        metas = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = [
            {"content": content, "meta": meta}
            for content, meta in zip(contents, metas)
        ]
        return passages, passages_df["scores"].values

    def retrieve_elastic(self, query: str, k_total: int, threshold: float):
        """Retrieve passages via Elasticsearch (placeholder).

        Raises:
            NotImplementedError: Always; no Elasticsearch backend exists yet.
        """
        raise NotImplementedError("Elasticsearch retrieval is not implemented.")

    @abstractmethod
    def _load_model(self, model_name: str, device: str):
        """Load the embedding model identified by ``model_name`` onto ``device``
        and return it; the result is stored as ``self.model``."""

    @abstractmethod
    def _generate_embeddings(self, text: str):
        """Return the embedding vector for a single ``text``."""
class SentenceTransformersTextEmbedder(TextEmbedder):
    """TextEmbedder backed by a ``sentence_transformers.SentenceTransformer``."""

    def _load_model(self, model_name: str, device: str):
        """Instantiate a SentenceTransformer and move it to ``device``.

        Args:
            model_name (str): Model identifier handed to ``SentenceTransformer``.
            device (str): Device spec converted with ``torch.device``
                (e.g. 'cpu' or 'cuda').

        Returns:
            SentenceTransformer: The model instance after ``.to(device)``.
        """
        st_model = SentenceTransformer(model_name)
        st_model.to(torch.device(device))
        return st_model

    def _generate_embeddings(self, text: str):
        """Encode a single text with the loaded model.

        Args:
            text (str): Text to embed.

        Returns:
            np.ndarray: The result of ``self.model.encode(text)``.
        """
        encoder = self.model
        return encoder.encode(text)
# class FinBertTextEmbedder(TextEmbedder):
# def _load_model(self, model_name: str, device: str):
# model = FinbertEmbedding(device=device)
# return model
# def _generate_embeddings(self, text: str):
# output = self.model.sentence_vector(text)
# return output.cpu().numpy()
|