# chatbot_app/app.py
import os
from dotenv import load_dotenv
import gradio as gr
from langchain_chroma import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import ChatOpenAI
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_huggingface import HuggingFaceEmbeddings
from datasets import load_dataset
import chromadb
from typing import List
from mixedbread_ai.client import MixedbreadAI
from tqdm import tqdm
# Load variables from a local .env file into the environment, if present
load_dotenv()

# Global params
CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
MODEL_EMB = "mxbai-embed-large"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")
# Mixedbread AI client, used below for reranking
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
# Set up ChromaDB: stream pre-computed embeddings from the Hub dataset
# into an in-memory Chroma collection
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")
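# Upsert is keyed on document ids, so re-running this does not duplicate rows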
for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.upsert(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()}")
db = Chroma(
    client=client,
    collection_name="embeddings_mxbai",
    embedding_function=HuggingFaceEmbeddings(model_name=MODEL_EMB),
)
# Reranker: wraps a vector-store retriever and reranks its results with the
# Mixedbread AI reranking API, keeping only the top k documents
class Reranker(BaseRetriever):
    retriever: VectorStoreRetriever
    k: int

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = self.retriever.invoke(query)
        results = mxbai_client.reranking(
            model=MODEL_RRK,
            query=query,
            input=[doc.page_content for doc in docs],
            return_input=True,
            top_k=self.k,
        )
        return [Document(page_content=res.input) for res in results.data]
# Set up retriever + reranker + LLM
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
reranker = Reranker(retriever=retriever, k=4)
# ChatOpenAI reads OPENAI_API_KEY from the environment
llm = ChatOpenAI(model=LLM_NAME, verbose=True)
# Contextualization prompt (in French): given the chat history, rewrite the
# latest user question as a standalone question, without answering it
contextualize_q_system_prompt = (
    "Compte tenu de l'historique de la conversation et de la dernière question de l'utilisateur, "
    "qui peut faire référence à un contexte présent dans l'historique, "
    "formulez une question autonome qui peut être comprise sans l'historique. "
    "Ne répondez PAS à la question : reformulez-la si nécessaire, "
    "sinon renvoyez-la telle quelle."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# Create the history-aware retriever
history_aware_retriever = create_history_aware_retriever(
    llm, reranker, contextualize_q_prompt
)
# QA prompt (in French): answer only from the retrieved context, and admit
# not knowing when the context is insufficient
system_prompt = (
    "Réponds à la question en te basant uniquement sur le contexte suivant. "
    "Si tu ne connais pas la réponse, dis que tu ne sais pas.\n\n{context}"
)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# Create the question-answer chain
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
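# rag_chain pipeline: rewrite question -> retrieve + rerank -> answer from context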
# Set up the conversation history: one ChatMessageHistory per session id
store = {}

def get_session_history(session_id: str) -> ChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
# Gradio interface
def chatbot(message, history):
    # NB: a single hard-coded session id means all users share one history
    session_id = "gradio_session"
    response = conversational_rag_chain.invoke(
        {"input": message},
        config={
            "configurable": {"session_id": session_id},
            "callbacks": [ConsoleCallbackHandler()],
        },
    )["answer"]
    return response
iface = gr.ChatInterface(
    chatbot,
    title="Dataltist Chatbot",
    description="Posez vos questions sur l'assurance",
    textbox=gr.Textbox(placeholder="Qu'est-ce que l'assurance multirisque habitation ?", container=False, scale=9),
    theme=gr.themes.Soft(primary_hue="red", secondary_hue="pink"),
    # examples=[
    #     "Qu'est-ce que l'assurance multirisque habitation ?",
    #     "Qu'est-ce que la garantie DTA ?",
    # ],
    retry_btn=None,
    undo_btn=None,
    submit_btn=gr.Button(value="Envoyer", icon="./send_icon.png", variant="primary"),
    clear_btn="Effacer la conversation",
)
if __name__ == "__main__":
iface.launch() # share=True
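
# A minimal smoke test of the chain without the UI (a sketch: the session id
# "smoke_test" is arbitrary; the question reuses the textbox placeholder):
#   out = conversational_rag_chain.invoke(
#       {"input": "Qu'est-ce que l'assurance multirisque habitation ?"},
#       config={"configurable": {"session_id": "smoke_test"}},
#   )
#   print(out["answer"])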