import os
from typing import List

import chromadb
import gradio as gr
from datasets import load_dataset
from dotenv import load_dotenv
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from mixedbread_ai.client import MixedbreadAI
from tqdm import tqdm

# Load API keys from a local .env file
load_dotenv()

# Global params
CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
MODEL_EMB = "mixedbread-ai/mxbai-embed-large-v1"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")

# Mixedbread AI client (used for reranking)
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)

# Set up ChromaDB: stream the pre-embedded dataset from the Hub and upsert it in batches
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")

for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.upsert(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()}")

db = Chroma(
    client=client,
    collection_name="embeddings_mxbai",
    embedding_function=HuggingFaceEmbeddings(model_name=MODEL_EMB),
)


# Reranker: wraps the vector-store retriever and reorders its hits
# with the Mixedbread reranking API, keeping the top k documents
class Reranker(BaseRetriever):
    retriever: VectorStoreRetriever
    k: int

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = self.retriever.invoke(query)
        results = mxbai_client.reranking(
            model=MODEL_RRK,
            query=query,
            input=[doc.page_content for doc in docs],
            return_input=True,
            top_k=self.k,
        )
        return [Document(page_content=res.input) for res in results.data]


# Set up reranker + LLM: fetch 25 candidates by similarity, keep the 4 best after reranking
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
reranker = Reranker(retriever=retriever, k=4)
llm = ChatOpenAI(model=LLM_NAME, verbose=True)
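# Optional sanity check (illustrative only, not required by the app): the
# reranking retriever can be invoked on its own before it is wired into the
# chains below, to confirm retrieval + reranking work end to end.
# docs = reranker.invoke("Qu'est-ce que l'assurance multirisque habitation ?")
# print([doc.page_content[:100] for doc in docs])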
# Contextualization prompt (in French): rewrite the latest user question as a
# standalone question given the chat history; do NOT answer it, only rephrase
# it if needed, otherwise return it unchanged.
contextualize_q_system_prompt = (
    "Compte tenu de l'historique des discussions et de la dernière question de l'utilisateur "
    "qui peut faire référence à un contexte dans l'historique du chat, "
    "formuler une question autonome qui peut être comprise "
    "sans l'historique du chat. Ne répondez PAS à la question, "
    "juste la reformuler si nécessaire et sinon la renvoyer telle quelle."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# Create the history-aware retriever
history_aware_retriever = create_history_aware_retriever(
    llm, reranker, contextualize_q_prompt
)

# QA prompt (in French): answer only from the retrieved context; if the answer
# is not in the context, say you don't know.
system_prompt = (
    "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}\n"
    "Si tu ne connais pas la réponse, dis que tu ne sais pas."
)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

# Create the question-answer chain
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

# Set up the conversation history (one in-memory history per session id)
store = {}

def get_session_history(session_id: str) -> ChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

# Gradio interface
def chatbot(message, history):
    session_id = "gradio_session"
    response = conversational_rag_chain.invoke(
        {"input": message},
        config={
            "configurable": {"session_id": session_id},
            "callbacks": [ConsoleCallbackHandler()],
        },
    )["answer"]
    return response

iface = gr.ChatInterface(
    chatbot,
    title="Dataltist Chatbot",
    description="Posez vos questions sur l'assurance",
    textbox=gr.Textbox(
        placeholder="Qu'est-ce que l'assurance multirisque habitation ?",
        container=False,
        scale=9,
    ),
    theme=gr.themes.Soft(primary_hue="red", secondary_hue="pink"),
    # examples=[
    #     "Qu'est-ce que l'assurance multirisque habitation ?",
    #     "Qu'est-ce que la garantie DTA ?",
    # ],
    retry_btn=None,
    undo_btn=None,
    submit_btn=gr.Button(value="Envoyer", icon="./send_icon.png", variant="primary"),
    clear_btn="Effacer la conversation",
)

if __name__ == "__main__":
    iface.launch()  # share=True
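# Usage note (assumed setup): OPENAI_API_KEY, MXBAI_API_KEY and HF_TOKEN must be
# available in the environment, e.g. via the .env file picked up by load_dotenv().
# Running the script first builds the Chroma collection from the streamed dataset,
# then serves the Gradio chat UI (http://127.0.0.1:7860 by default).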