from typing import Iterator

from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from loguru import logger

from rag_101.retriever import (
    RAGException,
    create_parent_retriever,
    load_embedding_model,
    load_pdf,
    load_reranker_model,
    retrieve_context,
)


class RAGClient:
    # The embedding and reranker models are loaded once, at class-definition
    # time, and shared by every RAGClient instance.
    embedding_model = load_embedding_model()
    reranker_model = load_reranker_model()

    def __init__(self, files, model="mistral"):
        # Parse the input PDFs and index them behind a parent-document retriever.
        docs = load_pdf(files=files)
        self.retriever = create_parent_retriever(docs, self.embedding_model)

        llm = ChatOllama(model=model)
        prompt_template = ChatPromptTemplate.from_template(
            (
                "Please answer the following question based on the provided "
                "`context` that follows the question.\n"
                "Think step by step before answering. If you do not know the "
                "answer, just say 'I do not know'.\n"
                "question: {question}\n"
                "context: ```{context}```\n"
            )
        )
        self.chain = prompt_template | llm | StrOutputParser()

    def stream(self, query: str) -> Iterator[str]:
        # Use the top reranked context; fall back to the exception message
        # if retrieval fails.
        try:
            context, similarity_score = self.retrieve_context(query)[0]
            context = context.page_content
            if similarity_score < 0.005:
                context = "This context may not be relevant. " + context
        except RAGException as e:
            context, similarity_score = e.args[0], 0
        logger.info(context)
        yield from self.chain.stream({"context": context, "question": query})

    def retrieve_context(self, query: str):
        return retrieve_context(
            query, retriever=self.retriever, reranker_model=self.reranker_model
        )

    def generate(self, query: str) -> dict:
        contexts = self.retrieve_context(query)
        return {
            "contexts": contexts,
            "response": self.chain.invoke(
                {"context": contexts[0][0].page_content, "question": query}
            ),
        }
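

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module. It assumes an
    # Ollama server is running locally with the `mistral` model pulled, and
    # that "docs/sample.pdf" (a hypothetical path) exists.
    client = RAGClient(files=["docs/sample.pdf"])

    # Blocking call: returns the reranked contexts and the model's answer.
    result = client.generate("What is the main topic of the document?")
    print(result["response"])

    # Streaming call: tokens are yielded as the model generates them.
    for token in client.stream("Summarize the document in two sentences."):
        print(token, end="", flush=True)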