# rag_pipeline.py
import os
import pickle
import logging
import asyncio

from keybert import KeyBERT

from app.search.bm25_search import BM25_search
from app.search.faiss_search import FAISS_search
from app.search.hybrid_search import Hybrid_search
from app.utils.token_counter import TokenCounter

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# A single module-level KeyBERT instance avoids reloading the underlying
# embedding model on every call.
_kw_model = KeyBERT()

def extract_keywords_async(doc, threshold=0.4, top_n=5):
    """Extract keywords from `doc` via KeyBERT and return them as a plain list
    of strings (scores are dropped). `threshold` and `top_n` are passed
    through to `KeyBERT.extract_keywords`.

    Note: despite its name, this function is synchronous.
    """
    keywords = _kw_model.extract_keywords(doc, threshold=threshold, top_n=top_n)
    return [key for key, _ in keywords]
class RAGSystem:
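    """Hybrid retrieval pipeline combining BM25 and FAISS with token accounting."""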
def __init__(self, embedding_model):
self.token_counter = TokenCounter()
self.documents = []
self.doc_ids = []
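        # Texts returned by adv_query accumulate here across calls;
        # get_context() joins them into a single prompt context.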
self.results = []
self.meta_data = []
self.embedding_model = embedding_model
self.bm25_wrapper = BM25_search()
self.faiss_wrapper = FAISS_search(embedding_model)
self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)
def add_document(self, doc_id, text, meta_data=None):
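        """Register a document and its metadata with the token counter and both indices."""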
self.token_counter.add_document(doc_id, text)
self.doc_ids.append(doc_id)
self.documents.append(text)
self.meta_data.append(meta_data)
self.bm25_wrapper.add_document(doc_id, text)
self.faiss_wrapper.add_document(doc_id, text)
    def delete_document(self, doc_id):
        """Remove a document from all in-memory stores and both indices."""
        try:
            index = self.doc_ids.index(doc_id)
            del self.doc_ids[index]
            del self.documents[index]
            del self.meta_data[index]  # keep metadata aligned with doc_ids/documents
            self.bm25_wrapper.remove_document(index)
            self.faiss_wrapper.remove_document(index)
            self.token_counter.remove_document(doc_id)
        except ValueError:
            logger.warning(f"Document ID {doc_id} not found.")
    async def adv_query(self, query_text, keywords, top_k=15, prefixes=None):
        """Run the hybrid search and resolve the returned IDs to stored documents."""
        results = await self.hybrid_search.advanced_search(
            query_text,
            keywords=keywords,
            top_n=top_k,
            threshold=0.43,
            prefixes=prefixes
        )
        if not results:
            # Sentinel entry kept for callers that expect at least one item.
            return [{"url": "None.", "text": None}]

        retrieved_docs = []
        seen_docs = set()
        for doc_id, score in results:
            if doc_id in seen_docs:
                continue
            seen_docs.add(doc_id)
            try:
                index = self.doc_ids.index(doc_id)
            except ValueError:
                logger.error(f"doc_id {doc_id} not found in self.doc_ids")
                continue
            # Guard against stores that have drifted out of alignment.
            if index >= len(self.documents) or index >= len(self.meta_data):
                logger.error(f"Index {index} out of range for documents or metadata")
                continue
            doc = self.documents[index]
            meta_data = self.meta_data[index] or {}  # metadata may be None
            url = meta_data.get('source', 'unknown_url')  # default URL fallback
            self.results.append(doc)
            retrieved_docs.append({"url": url, "text": doc})
        return retrieved_docs
def get_total_tokens(self):
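        """Total token count across all stored documents."""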
return self.token_counter.get_total_tokens()
def get_context(self):
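        """Join every document retrieved so far into one newline-separated context string."""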
context = "\n".join(self.results)
return context
def save_state(self, path):
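        """Pickle the document store; the BM25/FAISS indices are rebuilt on load, not saved."""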
# Save doc_ids, documents, and token counter state
with open(f"{path}_state.pkl", 'wb') as f:
pickle.dump({
"doc_ids": self.doc_ids,
"documents": self.documents,
"meta_data": self.meta_data,
"token_counts": self.token_counter.doc_tokens
}, f)
def load_state(self, path):
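        """Restore pickled state and rebuild the BM25/FAISS indices from the documents."""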
if os.path.exists(f"{path}_state.pkl"):
with open(f"{path}_state.pkl", 'rb') as f:
state_data = pickle.load(f)
self.doc_ids = state_data["doc_ids"]
self.documents = state_data["documents"]
self.meta_data = state_data["meta_data"]
self.token_counter.doc_tokens = state_data["token_counts"]
# Clear and rebuild BM25 and FAISS
self.bm25_wrapper.clear_documents()
self.faiss_wrapper.clear_documents()
for doc_id, document in zip(self.doc_ids, self.documents):
self.bm25_wrapper.add_document(doc_id, document)
self.faiss_wrapper.add_document(doc_id, document)
self.token_counter.total_tokens = sum(self.token_counter.doc_tokens.values())
logging.info("System state loaded successfully with documents and indices rebuilt.")
else:
logging.info("No previous state found, initializing fresh state.")
self.doc_ids = []
self.documents = []
self.meta_data = [] # Reset meta_data
self.token_counter = TokenCounter()
self.bm25_wrapper = BM25_search()
self.faiss_wrapper = FAISS_search(self.embedding_model)
self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)
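

# Minimal usage sketch. Assumptions (not confirmed by this file): FAISS_search
# accepts any model exposing sentence-transformers' `encode()` interface, and
# the wrapper APIs in app/search/ behave as their names suggest. The model
# name, document IDs, and URLs below are illustrative.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")
    rag = RAGSystem(model)
    rag.add_document("doc1", "FAISS performs efficient vector similarity search.",
                     meta_data={"source": "https://example.com/faiss"})
    rag.add_document("doc2", "BM25 ranks documents by keyword relevance.",
                     meta_data={"source": "https://example.com/bm25"})

    query = "How does similarity search work?"
    keywords = extract_keywords_async(query)
    docs = asyncio.run(rag.adv_query(query, keywords, top_k=5))
    print(docs)
    print(f"Total tokens indexed: {rag.get_total_tokens()}")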