# rag_pipeline.py

import asyncio
import logging
import os
import pickle

import numpy as np
from keybert import KeyBERT

from app.search.bm25_search import BM25_search
from app.search.faiss_search import FAISS_search
from app.search.hybrid_search import Hybrid_search
from app.utils.token_counter import TokenCounter

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# KeyBERT-based keyword extraction used to expand queries before retrieval.
def extract_keywords_async(doc, threshold=0.4, top_n=5):
    # Note: despite its name, this runs synchronously; offload it to a worker
    # thread (see the sketch below) when calling it from async code.
    kw_model = KeyBERT()
    keywords = kw_model.extract_keywords(doc, threshold=threshold, top_n=top_n)
    return [key for key, _ in keywords]

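# Hedged sketch (illustrative name, not used elsewhere in this module): a thin
# async wrapper so callers can await keyword extraction without blocking the
# event loop. Assumes Python 3.9+ for asyncio.to_thread.
async def extract_keywords_in_thread(doc, threshold=0.4, top_n=5):
    return await asyncio.to_thread(
        extract_keywords_async, doc, threshold=threshold, top_n=top_n
    )
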
class RAGSystem:
    def __init__(self, embedding_model):
        self.token_counter = TokenCounter()
        self.documents = []
        self.doc_ids = []
        self.results = []
        self.meta_data = []
        self.embedding_model = embedding_model
        self.bm25_wrapper = BM25_search()
        self.faiss_wrapper = FAISS_search(embedding_model)
        self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)

    def add_document(self, doc_id, text, meta_data=None):
        self.token_counter.add_document(doc_id, text)
        self.doc_ids.append(doc_id)
        self.documents.append(text)
        self.meta_data.append(meta_data)
        self.bm25_wrapper.add_document(doc_id, text)
        self.faiss_wrapper.add_document(doc_id, text)

    def delete_document(self, doc_id):
        try:
            index = self.doc_ids.index(doc_id)
            del self.doc_ids[index]
            del self.documents[index]
            del self.meta_data[index]  # keep metadata aligned with documents
            self.bm25_wrapper.remove_document(index)
            self.faiss_wrapper.remove_document(index)
            self.token_counter.remove_document(doc_id)
        except ValueError:
            logger.warning(f"Document ID {doc_id} not found.")

    async def adv_query(self, query_text, keywords, top_k=15, prefixes=None):
        results = await self.hybrid_search.advanced_search(
            query_text,
            keywords=keywords,
            top_n=top_k,
            threshold=0.43,
            prefixes=prefixes
        )
        retrieved_docs = []
        if results:
            seen_docs = set()
            for doc_id, score in results:
                if doc_id in seen_docs:
                    continue
                seen_docs.add(doc_id)

                # Fetch the index of the document; skip ids the store does not know.
                try:
                    index = self.doc_ids.index(doc_id)
                except ValueError:
                    logger.error(f"doc_id {doc_id} not found in self.doc_ids")
                    continue

                # Validate index range
                if index >= len(self.documents) or index >= len(self.meta_data):
                    logger.error(f"Index {index} out of range for documents or metadata")
                    continue

                doc = self.documents[index]
                # Only the source URL is kept from the metadata; guard against
                # documents that were added without metadata.
                meta_data = self.meta_data[index] or {}
                url = meta_data.get('source', 'unknown_url')

                self.results.append(doc)
                retrieved_docs.append({"url": url, "text": doc})
            return retrieved_docs
        else:
            return [{"url": "None.", "text": None}]

    def get_total_tokens(self):
        return self.token_counter.get_total_tokens()

    def get_context(self):
        return "\n".join(self.results)

    def save_state(self, path):
        # Save doc_ids, documents, metadata, and token counter state.
        with open(f"{path}_state.pkl", 'wb') as f:
            pickle.dump({
                "doc_ids": self.doc_ids,
                "documents": self.documents,
                "meta_data": self.meta_data,
                "token_counts": self.token_counter.doc_tokens
            }, f)

    def load_state(self, path):
        if os.path.exists(f"{path}_state.pkl"):
            with open(f"{path}_state.pkl", 'rb') as f:
                state_data = pickle.load(f)
                self.doc_ids = state_data["doc_ids"]
                self.documents = state_data["documents"]
                self.meta_data = state_data["meta_data"]
                self.token_counter.doc_tokens = state_data["token_counts"]

            # Clear and rebuild BM25 and FAISS
            self.bm25_wrapper.clear_documents()
            self.faiss_wrapper.clear_documents()
            for doc_id, document in zip(self.doc_ids, self.documents):
                self.bm25_wrapper.add_document(doc_id, document)
                self.faiss_wrapper.add_document(doc_id, document)

            self.token_counter.total_tokens = sum(self.token_counter.doc_tokens.values())
            logging.info("System state loaded successfully with documents and indices rebuilt.")
        else:
            logging.info("No previous state found, initializing fresh state.")
            self.doc_ids = []
            self.documents = []
            self.meta_data = []  # Reset meta_data
            self.token_counter = TokenCounter()
            self.bm25_wrapper = BM25_search()
            self.faiss_wrapper = FAISS_search(self.embedding_model)
            self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)
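
# ---------------------------------------------------------------------------
# Hedged usage sketch. The embedding model, document text, and paths below are
# assumptions for illustration only; FAISS_search is assumed to accept a
# SentenceTransformer-style encoder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer  # assumed dependency

    rag = RAGSystem(SentenceTransformer("all-MiniLM-L6-v2"))
    rag.add_document(
        "doc-1",
        "Toy example text about digital banking regulation.",
        meta_data={"source": "docs/example_regulation.pdf"},
    )

    query = "digital banking regulation"
    keywords = extract_keywords_async(query, top_n=3)

    # adv_query is a coroutine; drive it with asyncio.run when no loop is running.
    for item in asyncio.run(rag.adv_query(query, keywords, top_k=5)):
        print(item["url"], "->", (item["text"] or "")[:80])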