File size: 2,859 Bytes
4b549a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from abc import ABC, abstractmethod

import pandas as pd
import torch
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer

# from finbert_embedding.embedding import FinbertEmbedding


class TextEmbedder(ABC):
    """Abstract base for embedding a paragraph dataset and searching it.

    The dataset is loaded from disk with ``load_from_disk`` and is expected
    to expose a ``content`` column. Embeddings are stored in an
    ``embeddings`` column, which is also the name of the FAISS index used
    for retrieval. Subclasses supply the model loading and the embedding
    computation through ``_load_model`` and ``_generate_embeddings``.
    """

    def __init__(
        self,
        model_name: str,
        paragraphs_path: str,
        device: str,
        load_existing_index: bool = False,
    ):
        """Load the dataset, the model and, optionally, a saved FAISS index.

        Args:
            model_name: Forwarded to ``_load_model``.
            paragraphs_path: Location passed to ``load_from_disk``. When
                ``load_existing_index`` is truthy, the index is read from
                ``{paragraphs_path}/index.faiss``.
            device: Forwarded to ``_load_model``.
            load_existing_index: Whether to attach the saved FAISS index to
                the ``embeddings`` column.

        Raises:
            AssertionError: If the loaded dataset contains no rows.
        """
        self.dataset = load_from_disk(paragraphs_path)

        # Validate before loading the model so an empty dataset fails fast.
        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept for
        # backward compatibility.
        if len(self.dataset) == 0:
            raise AssertionError("The loaded dataset is empty !!")

        self.model = self._load_model(model_name, device)

        if load_existing_index:
            self.dataset.load_faiss_index(
                "embeddings", f"{paragraphs_path}/index.faiss"
            )

    def generate_paragraphs_embedding(self):
        """Add an ``embeddings`` column computed from each row's ``content``.

        The mapped dataset replaces ``self.dataset``.
        """
        self.dataset = self.dataset.map(
            lambda x: {"embeddings": self._generate_embeddings(x["content"])}
        )

    def save_embeddings(self, output_path):
        """Build a FAISS index on ``embeddings`` and save it.

        Args:
            output_path: Directory prefix; the index is written to
                ``{output_path}/index.faiss``.
        """
        self.dataset.add_faiss_index(column="embeddings")
        self.dataset.save_faiss_index("embeddings", f"{output_path}/index.faiss")

    def retrieve_faiss(self, query: str, k_total: int, threshold: float):
        """Return the nearest passages to ``query`` that pass ``threshold``.

        Args:
            query: Text embedded with ``_generate_embeddings``.
            k_total: Number of neighbours requested from the index.
            threshold: Exclusive lower bound applied to the raw index score
                divided by 100.

        Returns:
            A ``(passages, scores)`` pair. ``passages`` is a list of
            ``{"content": ..., "meta": ...}`` dicts, where ``meta`` holds
            every other column of the row (including ``scores``), ordered by
            descending score; ``scores`` is the matching array of scaled
            scores. When nothing passes the threshold both are empty lists.
        """
        question_embedding = self._generate_embeddings(query)
        scores, samples = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k_total
        )
        passages_df = pd.DataFrame(samples)
        # NOTE(review): filtering and sorting treat a higher score as a
        # better match. If the index uses an L2 metric, smaller raw scores
        # are closer -- confirm the index metric matches this assumption.
        passages_df["scores"] = scores / 100
        passages_df = passages_df[passages_df["scores"] > threshold]
        passages_df = passages_df.sort_values(by=["scores"], ascending=False)

        if passages_df.empty:
            return [], []

        contents = passages_df["content"].tolist()
        metas = passages_df.drop(columns=["content"]).to_dict(orient="records")
        passages = [
            {"content": content, "meta": meta}
            for content, meta in zip(contents, metas)
        ]
        return passages, passages_df["scores"].values

    def retrieve_elastic(self, query: str, k_total: int, threshold: float):
        """Placeholder for an Elasticsearch-backed retrieval; not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @abstractmethod
    def _load_model(self, model_name: str, device: str):
        """Return the embedding model; the result is stored as ``self.model``."""

    @abstractmethod
    def _generate_embeddings(self, text: str):
        """Return the embedding of ``text`` used for indexing and querying."""


class SentenceTransformersTextEmbedder(TextEmbedder):
    """``TextEmbedder`` backed by a ``SentenceTransformer`` model."""

    def _load_model(self, model_name: str, device: str):
        """Create the ``SentenceTransformer`` and move it to ``device``.

        Args:
            model_name: Identifier passed to ``SentenceTransformer``.
            device: Device string converted with ``torch.device``.

        Returns:
            The model after it has been moved to the requested device.
        """
        encoder = SentenceTransformer(model_name)
        encoder.to(torch.device(device))
        return encoder

    def _generate_embeddings(self, text: str):
        """Encode ``text`` with the loaded model and return the result."""
        embedding = self.model.encode(text)
        return embedding


# class FinBertTextEmbedder(TextEmbedder):
#     def _load_model(self, model_name: str, device: str):
#         model = FinbertEmbedding(device=device)
#         return model

#     def _generate_embeddings(self, text: str):
#         output = self.model.sentence_vector(text)
#         return output.cpu().numpy()