# rag_pipeline.py
import numpy as np
import pickle
import os
import logging
import asyncio

from keybert import KeyBERT

from app.search.bm25_search import BM25_search
from app.search.faiss_search import FAISS_search
from app.search.hybrid_search import Hybrid_search
from app.utils.token_counter import TokenCounter

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def extract_keywords_async(doc, threshold=0.4, top_n=5):
    """Extract up to top_n keywords above the score threshold using KeyBERT."""
    kw_model = KeyBERT()
    keywords = kw_model.extract_keywords(doc, threshold=threshold, top_n=top_n)
    return [key for key, _ in keywords]


class RAGSystem:
    def __init__(self, embedding_model):
        self.token_counter = TokenCounter()
        self.documents = []
        self.doc_ids = []
        self.results = []
        self.meta_data = []
        self.embedding_model = embedding_model
        self.bm25_wrapper = BM25_search()
        self.faiss_wrapper = FAISS_search(embedding_model)
        self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)

    def add_document(self, doc_id, text, meta_data=None):
        """Register a document with the token counter and both search indices."""
        self.token_counter.add_document(doc_id, text)
        self.doc_ids.append(doc_id)
        self.documents.append(text)
        self.meta_data.append(meta_data)
        self.bm25_wrapper.add_document(doc_id, text)
        self.faiss_wrapper.add_document(doc_id, text)

    def delete_document(self, doc_id):
        """Remove a document from all in-memory stores and both indices."""
        try:
            index = self.doc_ids.index(doc_id)
            del self.doc_ids[index]
            del self.documents[index]
            del self.meta_data[index]  # keep metadata aligned with doc_ids and documents
            self.bm25_wrapper.remove_document(index)
            self.faiss_wrapper.remove_document(index)
            self.token_counter.remove_document(doc_id)
        except ValueError:
            logger.warning(f"Document ID {doc_id} not found.")

    async def adv_query(self, query_text, keywords, top_k=15, prefixes=None):
        """Run the hybrid search and return de-duplicated documents with their source URLs."""
        results = await self.hybrid_search.advanced_search(
            query_text, keywords=keywords, top_n=top_k, threshold=0.43, prefixes=prefixes
        )
        retrieved_docs = []
        if results:
            seen_docs = set()
            for doc_id, score in results:
                if doc_id in seen_docs:
                    continue
                seen_docs.add(doc_id)

                # Fetch the index of the document
                try:
                    index = self.doc_ids.index(doc_id)
                except ValueError:
                    logger.error(f"doc_id {doc_id} not found in self.doc_ids")
                    continue

                # Validate index range
                if index >= len(self.documents) or index >= len(self.meta_data):
                    logger.error(f"Index {index} out of range for documents or metadata")
                    continue

                doc = self.documents[index]
                meta_data = self.meta_data[index]

                # Use the document's source as the URL, tolerating missing metadata
                url = (meta_data or {}).get('source', 'unknown_url')

                self.results.append(doc)
                retrieved_docs.append({"url": url, "text": doc})
            return retrieved_docs
        else:
            return [{"url": "None.", "text": None}]

    def get_total_tokens(self):
        return self.token_counter.get_total_tokens()

    def get_context(self):
        return "\n".join(self.results)

    def save_state(self, path):
        # Save doc_ids, documents, metadata, and token counter state
        with open(f"{path}_state.pkl", 'wb') as f:
            pickle.dump({
                "doc_ids": self.doc_ids,
                "documents": self.documents,
                "meta_data": self.meta_data,
                "token_counts": self.token_counter.doc_tokens
            }, f)

    def load_state(self, path):
        if os.path.exists(f"{path}_state.pkl"):
            with open(f"{path}_state.pkl", 'rb') as f:
                state_data = pickle.load(f)
            self.doc_ids = state_data["doc_ids"]
            self.documents = state_data["documents"]
            self.meta_data = state_data["meta_data"]
            self.token_counter.doc_tokens = state_data["token_counts"]

            # Clear and rebuild BM25 and FAISS
            self.bm25_wrapper.clear_documents()
            self.faiss_wrapper.clear_documents()
            for doc_id, document in zip(self.doc_ids, self.documents):
                self.bm25_wrapper.add_document(doc_id, document)
                self.faiss_wrapper.add_document(doc_id, document)

            self.token_counter.total_tokens = sum(self.token_counter.doc_tokens.values())
            logger.info("System state loaded successfully with documents and indices rebuilt.")
        else:
            logger.info("No previous state found, initializing fresh state.")
            self.doc_ids = []
            self.documents = []
            self.meta_data = []  # Reset meta_data
            self.token_counter = TokenCounter()
            self.bm25_wrapper = BM25_search()
            self.faiss_wrapper = FAISS_search(self.embedding_model)
            self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)