import os
from typing import List, Sequence

import anthropic
import cohere
import tiktoken
from dotenv import load_dotenv

load_dotenv()

from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings

from llm_constants import (
    COHERE_RERANKER_COST,
    EMBEDDINGS_MODEL_NAME,
    EMBEDDINGS_TOKENS_COST,
    INPUT_TOKENS_COST,
    LLM_MODEL_NAME,
    MAX_TOKENS,
    OUTPUT_TOKENS_COST,
    RERANKER_MODEL_NAME,
)
from prompts import CHAT_PROMPT, TOOLS
from pydantic_models import (
    ChatHistoryItem,
    RequestModel,
    ResponseModel,
    VectorStoreDocumentItem,
)


class RAGChatBot:
    __cohere_api_key = os.getenv("COHERE_API_KEY")
    __anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    __openai_api_key = os.getenv("OPENAI_API_KEY")
    __embedding_function = OpenAIEmbeddings(model=EMBEDDINGS_MODEL_NAME)
    __base_retriever = None
    __bm25_retriever = None
    anthropic_client = None
    cohere_client = None
    top_n: int = 3
    chat_history_length: int = 10

    def __init__(self, vectorstore_path: str, top_n: int = 3):
        if self.__cohere_api_key is None:
            raise ValueError("COHERE_API_KEY must be set in the environment")
        if self.__anthropic_api_key is None:
            raise ValueError("ANTHROPIC_API_KEY must be set in the environment")
        if self.__openai_api_key is None:
            raise ValueError("OPENAI_API_KEY must be set in the environment")
        if not isinstance(top_n, int):
            raise ValueError("top_n must be an integer")
        self.top_n = top_n
        self.set_base_retriever(vectorstore_path)
        self.set_anthropic_client()
        self.set_cohere_client()

    def set_base_retriever(self, vectorstore_path: str):
        db = FAISS.load_local(
            vectorstore_path,
            self.__embedding_function,
            allow_dangerous_deserialization=True,
        )
        self.__base_retriever = db.as_retriever(search_kwargs={"k": 25})
        # Build a keyword (BM25) retriever over the same documents for hybrid retrieval.
        self.__bm25_retriever = BM25Retriever.from_documents(
            list(db.docstore._dict.values()), k=25
        )

    def set_anthropic_client(self):
        self.anthropic_client = anthropic.Anthropic(api_key=self.__anthropic_api_key)

    def set_cohere_client(self):
        self.cohere_client = cohere.Client(self.__cohere_api_key)

    def make_llm_api_call(self, messages: list):
        return self.anthropic_client.messages.create(
            model=LLM_MODEL_NAME,
            max_tokens=MAX_TOKENS,
            temperature=0,
            messages=messages,
            tools=TOOLS,
        )

    def make_rerank_api_call(self, search_phrase: str, documents: Sequence[str]):
        return self.cohere_client.rerank(
            query=search_phrase,
            documents=documents,
            model=RERANKER_MODEL_NAME,
            top_n=self.top_n,
        )

    def retrieve_documents(self, search_phrase: str):
        # Hybrid retrieval: merge dense (FAISS) and keyword (BM25) hits,
        # keeping BM25 results first and dropping duplicates.
        similarity_documents = self.__base_retriever.invoke(search_phrase)
        bm25_documents = self.__bm25_retriever.invoke(search_phrase)
        unique_docs = []
        for doc in bm25_documents + similarity_documents:
            if doc not in unique_docs:
                unique_docs.append(doc)
        return unique_docs

    def retrieve_and_rerank(self, search_phrase: str):
        documents = self.retrieve_documents(search_phrase)
        if len(documents) == 0:
            # Avoid an empty rerank API call.
            return []
        # Filter first so that rerank result indices line up with `documents`.
        documents = [doc for doc in documents if isinstance(doc, Document)]
        docs = [doc.page_content for doc in documents]
        api_result = self.make_rerank_api_call(search_phrase, docs)
        reranked_docs = []
        max_score = max(res.relevance_score for res in api_result.results)
        threshold_score = max_score * 0.8
        for res in api_result.results:
            # Optionally drop results scoring well below the best match:
            # if res.relevance_score < threshold_score:
            #     continue
            doc = documents[res.index]
            document_item = VectorStoreDocumentItem(
                page_content=doc.page_content,
                filename=doc.metadata["filename"],
                heading=doc.metadata["heading"],
                relevance_score=res.relevance_score,
            )
            reranked_docs.append(document_item)
        return reranked_docs

    def get_context_and_docs(self, search_phrase: str):
        docs = self.retrieve_and_rerank(search_phrase)
        context = "\n\n\n".join(
            f"Filename: {doc.filename}\nHeading: {doc.heading}\n\n{doc.page_content}"
            for doc in docs
        )
        return context, docs

    def get_tool_use_assistant_message(self, tool_use_block):
        return {"role": "assistant", "content": tool_use_block}

    def get_tool_use_user_message(self, tool_use_id, context):
        return {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": tool_use_id,
                    "content": context,
                }
            ],
        }

    def process_tool_call(self, tool_name, tool_input):
        if tool_name == "Documents_Retriever":
            search_phrase = tool_input["search_phrase"]
            context, sources_list = self.get_context_and_docs(search_phrase)
            return sources_list, search_phrase, context
        raise ValueError(f"Unknown tool: {tool_name}")

    def calculate_cost(self, input_tokens, output_tokens, search_phrase):
        MILLION = 1_000_000
        llm_cost = input_tokens * (INPUT_TOKENS_COST / MILLION) + output_tokens * (
            OUTPUT_TOKENS_COST / MILLION
        )
        if search_phrase:
            # A retrieval happened: add the query-embedding and reranker costs.
            enc = tiktoken.get_encoding("cl100k_base")
            query_tokens = enc.encode(search_phrase)
            embeddings_cost = len(query_tokens) * (EMBEDDINGS_TOKENS_COST / MILLION)
            return llm_cost + embeddings_cost + COHERE_RERANKER_COST
        return llm_cost

    def chat_with_claude(self, user_message_history: list):
        input_tokens = 0
        output_tokens = 0
        message = self.make_llm_api_call(user_message_history)
        input_tokens += message.usage.input_tokens
        output_tokens += message.usage.output_tokens
        documents_list = []
        search_phrase = ""
        # Keep resolving tool calls until the model stops asking for tools.
        while message.stop_reason == "tool_use":
            tool_use = next(block for block in message.content if block.type == "tool_use")
            documents_list, search_phrase, tool_result = self.process_tool_call(
                tool_use.name, tool_use.input
            )
            user_message_history.append(self.get_tool_use_assistant_message(message.content))
            user_message_history.append(self.get_tool_use_user_message(tool_use.id, tool_result))
            message = self.make_llm_api_call(user_message_history)
            input_tokens += message.usage.input_tokens
            output_tokens += message.usage.output_tokens
        answer = next(
            (block.text for block in message.content if hasattr(block, "text")),
            None,
        )
        # Unwrap the answer if the model enclosed it in <answer>...</answer> tags.
        if answer and "<answer>" in answer:
            answer = answer.split("<answer>")[1].split("</answer>")[0].strip()
        total_cost = self.calculate_cost(input_tokens, output_tokens, search_phrase)
        return documents_list, search_phrase, answer, total_cost

    def get_chat_history_text(self, chat_history: List[ChatHistoryItem]):
        chat_history_text = ""
        for chat_message in chat_history:
            chat_history_text += (
                f"User: {chat_message.user_message}\n"
                f"Assistant: {chat_message.assistant_message}\n"
            )
        return chat_history_text.strip()

    def get_response(self, input: RequestModel) -> ResponseModel:
        chat_history = self.get_chat_history_text(input.chat_history)
        user_question = input.user_question
        user_prompt = CHAT_PROMPT.format(
            CHAT_HISTORY=chat_history, USER_QUESTION=user_question
        )
        if input.use_tool:
            user_prompt = f"{user_prompt}\nUse Documents_Retriever tool in your response."
        sources_list, search_phrase, answer, _ = self.chat_with_claude(
            [{"role": "user", "content": [{"type": "text", "text": user_prompt}]}]
        )
        updated_chat_history = input.chat_history.copy()
        updated_chat_history.append(
            ChatHistoryItem(user_message=user_question, assistant_message=answer)
        )
        return ResponseModel(
            answer=answer,
            sources_documents=sources_list,
            chat_history=updated_chat_history,
            search_phrase=search_phrase,
        )
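

# Minimal usage sketch, not part of the original module. Assumptions: a FAISS
# index was previously saved to "vectorstore/" (hypothetical path), the three
# API keys are available via .env, and the sample question is illustrative.
# RequestModel/ResponseModel field names follow their usage in get_response().
if __name__ == "__main__":
    bot = RAGChatBot(vectorstore_path="vectorstore/", top_n=3)
    request = RequestModel(
        user_question="What does the installation guide say about prerequisites?",
        chat_history=[],
        use_tool=True,
    )
    response = bot.get_response(request)
    print(response.answer)
    for doc in response.sources_documents:
        print(f"- {doc.filename} :: {doc.heading} ({doc.relevance_score:.3f})")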