File size: 3,377 Bytes
36623c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
beb7154
36623c8
 
beb7154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import json
import os
import time

import faiss
import numpy as np
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer

def init_cache(model_name="multilingual-e5-large"):
    """Create the FAISS index and sentence encoder used by the semantic cache.

    Args:
        model_name: SentenceTransformer model name or local path. Defaults to
            the model this module was written for.

    Returns:
        ``(index, encoder)``: an exact (brute-force) ``faiss.IndexFlatL2``
        sized to the encoder's embedding dimension, and the loaded encoder.
    """
    # Initialize Sentence Transformer model
    encoder = SentenceTransformer(model_name)

    # Size the index from the model instead of hard-coding 1024, so swapping
    # the model cannot cause a dimension mismatch at add()/search() time.
    dimension = encoder.get_sentence_embedding_dimension()
    if dimension is None:
        # Some models don't report their output size; measure it instead.
        dimension = len(encoder.encode(["dimension probe"])[0])

    index = faiss.IndexFlatL2(dimension)
    if index.is_trained:
        print("Index trained")

    return index, encoder

def retrieve_cache(json_file):
    """Load the persisted cache dictionary from ``json_file``.

    A missing file is not an error: a brand-new, empty cache structure with
    ``query``, ``embeddings``, ``answers`` and ``response_text`` lists is
    returned instead. Any other problem (e.g. malformed JSON) propagates.
    """
    try:
        handle = open(json_file, "r")
    except FileNotFoundError:
        # First run: nothing persisted yet.
        return {"query": [], "embeddings": [], "answers": [], "response_text": []}

    with handle:
        return json.load(handle)

def store_cache(json_file, cache):
    """Persist ``cache`` to ``json_file`` as JSON without risking the old copy.

    The data is written to ``<json_file>.tmp`` first and then moved over
    ``json_file`` with ``os.replace``. Opening the real file with ``"w"``
    would truncate it before ``json.dump`` runs. If serialization then failed
    part-way (e.g. a non-JSON value in document metadata), the previous cache
    would be lost or left half-written and unreadable on the next start.

    Raises:
        TypeError / ValueError: if ``cache`` is not JSON-serializable (the
            existing ``json_file`` is left untouched).
        OSError: on filesystem errors.
    """
    target = os.fspath(json_file)
    tmp_path = f"{target}.tmp"
    try:
        with open(tmp_path, "w") as file:
            json.dump(cache, file)
        os.replace(tmp_path, target)
    except BaseException:
        # Don't leave a stray partial temp file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

class SemanticCache:
    """Semantic cache in front of a LangChain-style retriever.

    Each query is embedded with a SentenceTransformer encoder and looked up
    in a FAISS ``IndexFlatL2``. If the nearest cached query is within
    ``euclidean_threshold``, the documents stored for it are returned.
    Otherwise the wrapped ``retriever`` is queried. The query, its embedding
    and the returned documents are then appended to the cache and persisted
    to ``json_file``.

    Invariant: row ``i`` of the FAISS index corresponds to entry ``i`` of
    ``cache["query"]``, ``cache["embeddings"]`` and ``cache["answers"]``.
    """

    def __init__(self, retriever, json_file="cache_file.json", thresold=0.35):
        """Create the cache and reload any previously persisted entries.

        Args:
            retriever: Fallback retriever exposing
                ``get_relevant_documents(query)``.
            json_file: Path of the JSON file the cache is persisted to.
            thresold: (sic; name kept for backward compatibility) maximum
                distance reported by FAISS for a lookup to count as a hit.
                0 means an identical embedding. ``IndexFlatL2`` reports
                *squared* L2 distances, so this bounds the squared distance.

        Raises:
            ValueError: if embeddings stored in ``json_file`` do not match
                the encoder's dimensionality.
        """
        self.retriever = retriever
        self.index, self.encoder = init_cache()

        # Only cached queries at or under this distance are served from cache.
        self.euclidean_threshold = thresold

        self.json_file = json_file
        self.cache = retrieve_cache(self.json_file)
        self._rebuild_index()

    def _rebuild_index(self):
        """Load persisted embeddings into the freshly created FAISS index.

        Without this, after a restart the index would start again at row 0
        while ``cache["answers"]`` already holds earlier entries. Every hit
        on a newly added query would then return the wrong stored answer.
        """
        for key in ("query", "embeddings", "answers", "response_text"):
            self.cache.setdefault(key, [])

        embeddings = self.cache["embeddings"]
        if not embeddings:
            return

        matrix = np.asarray(embeddings, dtype="float32")
        if matrix.ndim != 2 or matrix.shape[1] != self.index.d:
            raise ValueError(
                f"Cached embeddings in {self.json_file!r} have shape {matrix.shape}, "
                f"expected (n, {self.index.d}); delete the cache file or use "
                "the matching encoder."
            )
        self.index.add(matrix)

    def query_database(self, query_text):
        """Ask the wrapped retriever directly, bypassing the cache."""
        results = self.retriever.get_relevant_documents(query_text)
        return results

    def get_relevant_documents(self, query: str, use_cache=True) -> "list[Document]":
        """Return documents for ``query``, served from the cache when possible.

        Args:
            query: User query text.
            use_cache: When False, skip the cache lookup and always ask the
                retriever. The fresh result is still added to the cache.

        Returns:
            A list of ``Document`` objects. On a hit they are rebuilt from
            the stored dicts; on a miss this is exactly what the retriever
            returned.

        Raises:
            RuntimeError: wrapping any error raised while encoding,
                searching, retrieving or persisting. The original exception
                is chained as ``__cause__``.
        """
        start_time = time.perf_counter()
        try:
            # encode() takes a batch, so this is a 2-D array with one row,
            # which is the shape FAISS search()/add() expect.
            embedding = self.encoder.encode([query])

            if use_cache:
                distances, indices = self.index.search(embedding, 1)
                row_id = int(indices[0][0])
                distance = float(distances[0][0])
                # FAISS reports row -1 when there is no neighbour (empty index).
                if row_id >= 0 and distance <= self.euclidean_threshold:
                    print("Answer recovered from Cache. ")
                    print(f"{distance:.3f} smaller than {self.euclidean_threshold}")
                    print(f"Found cache in row: {row_id} with score {distance:.3f}")

                    elapsed_time = time.perf_counter() - start_time
                    print(f"Time taken: {elapsed_time:.3f} seconds")
                    return [Document(**doc) for doc in self.cache["answers"][row_id]]

            # Cache miss (or cache bypassed): ask the underlying retriever.
            answer = self.query_database(query)

            self.cache["query"].append(query)
            self.cache["embeddings"].append(embedding[0].tolist())
            self.cache["answers"].append([doc.__dict__ for doc in answer])

            # The new index row lines up with the entries just appended.
            self.index.add(embedding)
            store_cache(self.json_file, self.cache)

            elapsed_time = time.perf_counter() - start_time
            print(f"Time taken: {elapsed_time:.3f} seconds")

            return answer
        except Exception as e:
            raise RuntimeError(f"Error during 'get_relevant_documents' method: {e}") from e