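"""RAG chatbot that combines FAISS similarity search and BM25 keyword retrieval,
reranks candidates with Cohere, and answers questions with Claude via tool use."""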
import os
from typing import List, Sequence

from dotenv import load_dotenv

load_dotenv()

import anthropic
import cohere
import tiktoken
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings

from llm_constants import (
    LLM_MODEL_NAME,
    MAX_TOKENS,
    RERANKER_MODEL_NAME,
    EMBEDDINGS_MODEL_NAME,
    EMBEDDINGS_TOKENS_COST,
    INPUT_TOKENS_COST,
    OUTPUT_TOKENS_COST,
    COHERE_RERANKER_COST,
)
from prompts import CHAT_PROMPT, TOOLS
from pydantic_models import (
    RequestModel,
    ResponseModel,
    ChatHistoryItem,
    VectorStoreDocumentItem,
)

class RAGChatBot:
    """Retrieval-augmented chatbot: hybrid FAISS + BM25 retrieval,
    Cohere reranking, and Anthropic's Claude with tool use."""

    __cohere_api_key = os.getenv("COHERE_API_KEY")
    __anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    __openai_api_key = os.getenv("OPENAI_API_KEY")
    __embedding_function = OpenAIEmbeddings(model=EMBEDDINGS_MODEL_NAME)
    __base_retriever = None
    __bm25_retriever = None
    anthropic_client = None
    cohere_client = None
    top_n: int = 3
    chat_history_length: int = 10
    def __init__(self, vectorstore_path: str, top_n: int = 3):
        if self.__cohere_api_key is None:
            raise ValueError("COHERE_API_KEY must be set in the environment")
        if self.__anthropic_api_key is None:
            raise ValueError("ANTHROPIC_API_KEY must be set in the environment")
        if self.__openai_api_key is None:
            raise ValueError("OPENAI_API_KEY must be set in the environment")
        if not isinstance(top_n, int):
            raise ValueError("top_n must be an integer")
        self.top_n = top_n
        self.set_base_retriever(vectorstore_path)
        self.set_anthropic_client()
        self.set_cohere_client()
    def set_base_retriever(self, vectorstore_path: str):
        db = FAISS.load_local(
            vectorstore_path,
            self.__embedding_function,
            allow_dangerous_deserialization=True,
        )
        self.__base_retriever = db.as_retriever(search_kwargs={"k": 25})
        # Build a BM25 (keyword) retriever over the same documents for hybrid retrieval.
        self.__bm25_retriever = BM25Retriever.from_documents(
            list(db.docstore._dict.values()), k=25
        )

    def set_anthropic_client(self):
        self.anthropic_client = anthropic.Anthropic(api_key=self.__anthropic_api_key)

    def set_cohere_client(self):
        self.cohere_client = cohere.Client(self.__cohere_api_key)
    def make_llm_api_call(self, messages: list):
        return self.anthropic_client.messages.create(
            model=LLM_MODEL_NAME,
            max_tokens=MAX_TOKENS,
            temperature=0,
            messages=messages,
            tools=TOOLS,
        )

    def make_rerank_api_call(self, search_phrase: str, documents: Sequence[str]):
        return self.cohere_client.rerank(
            query=search_phrase,
            documents=documents,
            model=RERANKER_MODEL_NAME,
            top_n=self.top_n,
        )
    def retrieve_documents(self, search_phrase: str):
        # Hybrid retrieval: merge BM25 (keyword) and FAISS (semantic) results,
        # dropping duplicates while preserving order (BM25 hits first).
        similarity_documents = self.__base_retriever.invoke(search_phrase)
        bm25_documents = self.__bm25_retriever.invoke(search_phrase)
        unique_docs = []
        for doc in bm25_documents + similarity_documents:
            if doc not in unique_docs:
                unique_docs.append(doc)
        return unique_docs
    def retrieve_and_rerank(self, search_phrase: str):
        documents = self.retrieve_documents(search_phrase)
        # Filter first so that rerank result indices line up with `documents`.
        documents = [doc for doc in documents if isinstance(doc, Document)]
        if not documents:  # avoid an empty rerank API call
            return []
        docs = [doc.page_content for doc in documents]
        api_result = self.make_rerank_api_call(search_phrase, docs)
        reranked_docs = []
        # Optional relevance cutoff at 80% of the top score (currently disabled).
        max_score = max(res.relevance_score for res in api_result.results)
        threshold_score = max_score * 0.8
        for res in api_result.results:
            # if res.relevance_score < threshold_score:
            #     continue
            doc = documents[res.index]
            document_item = VectorStoreDocumentItem(
                page_content=doc.page_content,
                filename=doc.metadata["filename"],
                heading=doc.metadata["heading"],
                relevance_score=res.relevance_score,
            )
            reranked_docs.append(document_item)
        return reranked_docs
    def get_context_and_docs(self, search_phrase: str):
        docs = self.retrieve_and_rerank(search_phrase)
        context = "\n\n\n".join(
            f"Filename: {doc.filename}\nHeading: {doc.heading}\n\n{doc.page_content}"
            for doc in docs
        )
        return context, docs
    def get_tool_use_assistant_message(self, tool_use_block):
        return {"role": "assistant", "content": tool_use_block}

    def get_tool_use_user_message(self, tool_use_id, context):
        return {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": tool_use_id,
                    "content": context,
                }
            ],
        }
    def process_tool_call(self, tool_name, tool_input):
        if tool_name == "Documents_Retriever":
            search_phrase = tool_input["search_phrase"]
            context, sources_list = self.get_context_and_docs(search_phrase)
            return sources_list, search_phrase, context
        # Fail loudly instead of implicitly returning None for unknown tools.
        raise ValueError(f"Unknown tool requested by the model: {tool_name}")
    def calculate_cost(self, input_tokens, output_tokens, search_phrase):
        MILLION = 1_000_000  # costs are quoted per million tokens
        llm_cost = (
            input_tokens * (INPUT_TOKENS_COST / MILLION)
            + output_tokens * (OUTPUT_TOKENS_COST / MILLION)
        )
        if search_phrase:
            # The retrieval path also incurs embedding and reranker costs.
            enc = tiktoken.get_encoding("cl100k_base")
            query_tokens = len(enc.encode(search_phrase))
            embeddings_cost = query_tokens * (EMBEDDINGS_TOKENS_COST / MILLION)
            return embeddings_cost + COHERE_RERANKER_COST + llm_cost
        return llm_cost
    def chat_with_claude(self, user_message_history: list):
        input_tokens = 0
        output_tokens = 0
        message = self.make_llm_api_call(user_message_history)
        input_tokens += message.usage.input_tokens
        output_tokens += message.usage.output_tokens
        documents_list = []
        search_phrase = ""
        # Keep satisfying tool calls until the model stops asking for them.
        while message.stop_reason == "tool_use":
            tool_use = next(block for block in message.content if block.type == "tool_use")
            documents_list, search_phrase, tool_result = self.process_tool_call(
                tool_use.name, tool_use.input
            )
            user_message_history.append(self.get_tool_use_assistant_message(message.content))
            user_message_history.append(self.get_tool_use_user_message(tool_use.id, tool_result))
            message = self.make_llm_api_call(user_message_history)
            input_tokens += message.usage.input_tokens
            output_tokens += message.usage.output_tokens
        answer = next(
            (block.text for block in message.content if hasattr(block, "text")),
            None,
        )
        # Guard against a missing text block before looking for the <answer> tag.
        if answer and "<answer>" in answer:
            answer = answer.split("<answer>")[1].split("</answer>")[0].strip()
        total_cost = self.calculate_cost(input_tokens, output_tokens, search_phrase)
        return documents_list, search_phrase, answer, total_cost
    def get_chat_history_text(self, chat_history: List[ChatHistoryItem]):
        chat_history_text = ""
        for chat_message in chat_history:
            chat_history_text += (
                f"User: {chat_message.user_message}\n"
                f"Assistant: {chat_message.assistant_message}\n"
            )
        return chat_history_text.strip()
    def get_response(self, request: RequestModel) -> ResponseModel:
        chat_history = self.get_chat_history_text(request.chat_history)
        user_question = request.user_question
        user_prompt = CHAT_PROMPT.format(CHAT_HISTORY=chat_history, USER_QUESTION=user_question)
        if request.use_tool:
            user_prompt = f"{user_prompt}\nUse Documents_Retriever tool in your response."
        sources_list, search_phrase, answer, _ = self.chat_with_claude(
            [{"role": "user", "content": [{"type": "text", "text": user_prompt}]}]
        )
        updated_chat_history = request.chat_history.copy()
        updated_chat_history.append(
            ChatHistoryItem(user_message=user_question, assistant_message=answer)
        )
        return ResponseModel(
            answer=answer,
            sources_documents=sources_list,
            chat_history=updated_chat_history,
            search_phrase=search_phrase,
        )
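# Minimal usage sketch. Assumptions: "./vectorstore" is a hypothetical path to a
# previously saved FAISS index, and the RequestModel / ChatHistoryItem /
# VectorStoreDocumentItem fields match how they are used above; adapt both to
# your actual index location and pydantic models.
if __name__ == "__main__":
    bot = RAGChatBot(vectorstore_path="./vectorstore", top_n=3)
    request = RequestModel(
        user_question="What does the setup guide say about configuration?",
        chat_history=[],
        use_tool=True,
    )
    response = bot.get_response(request)
    print(response.answer)
    for doc in response.sources_documents:
        print(f"- {doc.filename} ({doc.heading}): {doc.relevance_score:.3f}")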