|
import concurrent.futures
import os

from dotenv import load_dotenv
from loguru import logger
from openai import OpenAI

from rag_demo.preprocessing.base import EmbeddedChunk
from rag_demo.rag.base.query import EmbeddedQuery, Query

from .prompt_templates import AnswerGenerationTemplate
from .query_classifier import QueryClassifier
from .query_expansion import QueryExpansion
from .reranker import Reranker
from .source_annotator import SourceAnnotator

load_dotenv()


def flatten(nested_list: list) -> list:
    """Flatten a list of lists into a single list."""
    return [item for sublist in nested_list for item in sublist]


class RAGPipeline:
    """Retrieval-augmented generation pipeline: query expansion, vector search,
    reranking, answer generation, and source annotation."""

    def __init__(self, mock: bool = False) -> None:
        self._query_expander = QueryExpansion(mock=mock)
        self._reranker = Reranker(mock=mock)
        self._source_annotator = SourceAnnotator()
        self._query_classifier = QueryClassifier(mock=mock)

    def search(
        self,
        query: str,
        k: int = 3,
        expand_to_n_queries: int = 3,
    ) -> list[EmbeddedChunk]:
        """Expand the query into several variants, search them concurrently,
        and return at most k deduplicated chunks."""
        query_model = Query.from_str(query)

        n_generated_queries = self._query_expander.generate(
            query_model, expand_to_n=expand_to_n_queries
        )
        logger.info(f"Successfully generated {len(n_generated_queries)} search queries.")

        # Run one vector search per expanded query in parallel.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            search_tasks = [
                executor.submit(self._search, _query_model, k)
                for _query_model in n_generated_queries
            ]
            n_k_documents = [
                task.result() for task in concurrent.futures.as_completed(search_tasks)
            ]

        # Merge the per-query result lists and deduplicate. Note that set()
        # requires EmbeddedChunk to be hashable and does not preserve order.
        n_k_documents = flatten(n_k_documents)
        n_k_documents = list(set(n_k_documents))

        logger.info(f"{len(n_k_documents)} documents retrieved successfully")

        # Slicing an empty list yields an empty list, so no special case is needed.
        return n_k_documents[:k]

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        """Embed a single query with OpenAI and run a vector search against
        the EmbeddedChunk collection."""
        assert k >= 3, "k should be >= 3"

        def _search_data(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k,
            )

        # Embed the query text, then wrap it together with its id and content.
        api = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        embedding = (
            api.embeddings.create(model="text-embedding-3-small", input=query.content)
            .data[0]
            .embedding
        )
        embedded_query = EmbeddedQuery(
            embedding=embedding,
            id=query.id,
            content=query.content,
        )

        retrieved_chunks = _search_data(EmbeddedChunk, embedded_query)
        logger.info(f"{len(retrieved_chunks)} documents retrieved successfully")

        return retrieved_chunks

    def rerank(
        self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int
    ) -> list[EmbeddedChunk]:
        """Rerank retrieved chunks against the query and keep the top keep_top_k."""
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(
            query=query, chunks=chunks, keep_top_k=keep_top_k
        )

        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents

    def generate_answer(self, query: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        """Build a context block from the chunks and ask the chat model for an answer."""
        context = ""
        for chunk in reranked_chunks:
            context += "\n Document: "
            context += chunk.content

        api = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        answer_generation_template = AnswerGenerationTemplate()
        prompt = answer_generation_template.create_template(context, query)
        logger.info(prompt)

        response = api.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=8192,
        )
        return response.choices[0].message.content

    def add_context(self, response: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        """Annotate the generated answer with its supporting sources."""
        logger.info("Adding context to the answer")
        return self._source_annotator.annotate(response, reranked_chunks)

    def rag(self, query: str) -> tuple[str, list[str]]:
        """End-to-end pipeline: classify the query, retrieve sources if needed,
        generate an answer, and return it with the list of source filenames."""
        query_type = self._query_classifier.generate(Query.from_str(query))
        logger.info(f"Query type: {query_type}")

        # Only retrieve documents when the classifier says sources are needed.
        docs = self.search(query, k=10) if query_type == "Sources_needed" else []

        answer = self.generate_answer(query, docs)

        annotated_answer = self.add_context(answer, docs) if docs else answer

        # Source names are the PDF filenames with the extension stripped.
        sources = list({doc.metadata["filename"].split(".pdf")[0] for doc in docs})
        return annotated_answer, sources
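

# Minimal usage sketch (not part of the pipeline itself). It assumes
# OPENAI_API_KEY is set in the environment (or loaded from .env) and that the
# EmbeddedChunk collection has already been populated by the preprocessing step;
# the query string below is a placeholder.
if __name__ == "__main__":
    pipeline = RAGPipeline()
    answer, sources = pipeline.rag("What does the ingested corpus say about RAG?")
    print(answer)
    print(f"Sources: {sources}")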
|
|