File size: 5,867 Bytes
de77992
 
 
42f87c6
de77992
1bea5ac
42f87c6
 
 
 
 
 
 
 
 
 
 
 
 
de77992
42f87c6
de77992
 
 
31d0102
42f87c6
 
 
 
 
 
 
 
543c1bb
9c5d425
42f87c6
de77992
 
42f87c6
 
1bea5ac
42f87c6
31d0102
8216547
42f87c6
de77992
42f87c6
de77992
846bd95
31d0102
 
 
 
42f87c6
31d0102
9e84bb1
1bea5ac
 
 
 
 
42f87c6
 
 
 
 
 
 
 
 
 
 
 
31d0102
42f87c6
 
 
 
 
1bea5ac
42f87c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15ec7ce
42f87c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9082c6d
f069c91
9082c6d
15ec7ce
aa67bf4
15ec7ce
 
 
 
42f87c6
 
7430b5c
42f87c6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
from dotenv import load_dotenv

import gradio as gr

from langchain_chroma import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import ChatOpenAI
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_huggingface import HuggingFaceEmbeddings 

from datasets import load_dataset
import chromadb
from typing import List
from mixedbread_ai.client import MixedbreadAI
from tqdm import tqdm

# Global params
# Load variables from a local .env file (if present) before they are read
# below. `load_dotenv` was imported but never called, so .env values were
# silently ignored; by default it does not override already-set variables.
load_dotenv()

# NOTE(review): CHROMA_PATH, MODEL_EMB and HF_API_KEY are not referenced by the
# code in this file — confirm whether they are still needed.
CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
MODEL_EMB = "mxbai-embed-large"
# Hosted Mixedbread reranking model used by the Reranker retriever.
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
# OpenAI chat model used both to rephrase questions and to answer them.
LLM_NAME = "gpt-4o-mini"
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")

# Hugging Face model id of the embedding model used to embed queries
# (passed to HuggingFaceEmbeddings when the Chroma wrapper is built).
model_emb = "mixedbread-ai/mxbai-embed-large-v1"

# MixedbreadAI client, authenticated with MXBAI_API_KEY; used for reranking.
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)

# Set up ChromaDB: stream the precomputed vectors from the Hugging Face Hub
# into a Chroma collection, then wrap that collection for LangChain.
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
# Each batch is a dict mapping column name -> list of values for that batch.
# NOTE(review): 41000 must stay below Chroma's maximum batch size — confirm
# for the installed chromadb version.
batched_ds = memoires_ds.batch(batch_size=41000)
# chromadb.Client() with no settings is an in-memory client: the collection is
# rebuilt from the dataset on every start-up (CHROMA_PATH is not used here).
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")

for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    # Upsert (not add) so re-running against an existing collection does not
    # fail on duplicate ids.
    collection.upsert(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()}")

db = Chroma(
    client=client,
    # Reuse the populated collection's name instead of repeating the literal,
    # so the LangChain wrapper always queries the collection filled above.
    collection_name=collection.name,
    # Query-side embedder; assumed to match the model that produced the stored
    # vectors — NOTE(review): confirm.
    embedding_function=HuggingFaceEmbeddings(model_name=model_emb),
)


# Reranker class
class Reranker(BaseRetriever):
    """Two-stage retriever: vector search for candidates, then Mixedbread rerank.

    The wrapped ``retriever`` supplies candidate documents. The hosted
    ``MODEL_RRK`` model scores them against the query, and the ``k``
    best-scoring documents are returned in reranked order.
    """

    # First-stage (vector similarity) retriever producing the candidates.
    retriever: VectorStoreRetriever
    # Number of documents kept after reranking.
    k: int

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the top ``k`` candidates for ``query`` after reranking.

        Args:
            query: The (possibly rephrased) user question.
            run_manager: LangChain callback manager (unused).

        Returns:
            Up to ``k`` documents, best first. Each one keeps the metadata of
            the original candidate and gains a ``"rerank_score"`` entry. An
            empty list is returned when the first stage finds nothing.
        """
        docs = self.retriever.invoke(query)
        if not docs:
            # Nothing to rerank: skip the API call with an empty input list.
            return []
        results = mxbai_client.reranking(
            model=MODEL_RRK,
            query=query,
            input=[doc.page_content for doc in docs],
            top_k=self.k,
        )
        # Each ranked result carries the index of the input it scored. Map it
        # back to the original Document so that its metadata is not lost, as
        # it was when rebuilding Documents from the echoed input text only.
        return [
            Document(
                page_content=docs[res.index].page_content,
                metadata={**docs[res.index].metadata, "rerank_score": res.score},
            )
            for res in results.data
        ]

# Set up reranker + LLM: fetch 25 candidates by vector similarity, keep the
# 4 best after reranking, and use the LLM_NAME OpenAI chat model.
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 25},
)
reranker = Reranker(k=4, retriever=retriever)
# No api_key is passed: ChatOpenAI reads it from the environment (OPENAI_API_KEY).
llm = ChatOpenAI(verbose=True, model=LLM_NAME)

# System instruction (in French) for rewriting the latest user question into a
# standalone question, given the chat history, without answering it.
contextualize_q_system_prompt = " ".join(
    (
        "Compte tenu de l'historique des discussions et de la dernière question de l'utilisateur",
        "qui peut faire référence à un contexte dans l'historique du chat,",
        "formuler une question autonome qui peut être comprise",
        "sans l'historique du chat. Ne répondez PAS à la question,",
        "juste la reformuler si nécessaire et sinon la renvoyer telle quelle.",
    )
)

# Messages for the rephrasing step: system instruction, prior turns
# (filled from "chat_history"), then the new user input.
_contextualize_messages = [
    ("system", contextualize_q_system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
]
contextualize_q_prompt = ChatPromptTemplate.from_messages(_contextualize_messages)

# Create the history-aware retriever: the LLM rephrases the question with
# contextualize_q_prompt, and the reranking retriever is queried with the result.
history_aware_retriever = create_history_aware_retriever(
    llm=llm,
    retriever=reranker,
    prompt=contextualize_q_prompt,
)

# Set up the QA prompt: answer only from the retrieved context (inserted at
# {context}), and admit ignorance otherwise.
# The blank line after {context} keeps the fallback instruction separate from
# the retrieved text. Without it, the implicit string concatenation glued
# "Si tu ne connais…" onto the end of the last document.
system_prompt = (
    "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}\n\n"
    "Si tu ne connais pas la réponse, dis que tu ne sais pas."
)
# Answering prompt: system instruction with the {context} slot, prior turns
# (filled from "chat_history"), then the user input.
_qa_messages = [
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
]
qa_prompt = ChatPromptTemplate.from_messages(_qa_messages)

# Create the question-answer chain: the retrieved documents are stuffed into
# {context}. The retrieval chain runs the history-aware retriever first and
# passes its documents to the answering chain.
question_answer_chain = create_stuff_documents_chain(llm=llm, prompt=qa_prompt)
rag_chain = create_retrieval_chain(
    retriever=history_aware_retriever,
    combine_docs_chain=question_answer_chain,
)

# Set up the conversation history: session id -> in-memory ChatMessageHistory.
store = {}

def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for ``session_id``, creating an empty one on first use."""
    try:
        return store[session_id]
    except KeyError:
        fresh = store[session_id] = ChatMessageHistory()
        return fresh

# Wrap the RAG chain so each call reads and extends the session's history.
# "input" is the user message, "chat_history" feeds the MessagesPlaceholder
# used in both prompts, and "answer" is the output recorded as the reply.
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    history_messages_key="chat_history",
    output_messages_key="answer",
    input_messages_key="input",
)

# Gradio interface
def chatbot(message, history, request: gr.Request = None):
    """Answer one chat message through the conversational RAG chain.

    Args:
        message: The user's latest message.
        history: Gradio's displayed conversation before ``message``. The chain
            keeps its own server-side history in ``store``; this argument is
            only used to detect a fresh conversation (first turn, or after the
            chat was cleared in the UI).
        request: Injected by Gradio for parameters annotated with
            ``gr.Request``. Its ``session_hash`` keys the history per browser
            session. When absent, a single shared id is used, as before.

    Returns:
        The ``"answer"`` string produced by the chain.
    """
    # Key the history per browser session. A single fixed id would make every
    # visitor read and extend the same conversation history.
    session_id = getattr(request, "session_hash", None) or "gradio_session"
    if not history:
        # The visible conversation is empty: drop any stale server-side history
        # so that a cleared chat really starts over.
        store.pop(session_id, None)
    response = conversational_rag_chain.invoke(
        {"input": message},
        config={
            "configurable": {"session_id": session_id},
            "callbacks": [ConsoleCallbackHandler()],
        },
    )["answer"]
    return response

# Custom widgets, created in the same order as before: input box, theme,
# send button.
_input_box = gr.Textbox(placeholder="Qu'est-ce que l'assurance multirisque habitation ?", container=False, scale=9)
_ui_theme = gr.themes.Soft(primary_hue="red", secondary_hue="pink")
_send_button = gr.Button(value="Envoyer", icon="./send_icon.png", variant="primary")

# Chat UI around `chatbot`. The retry and undo buttons are disabled, and the
# clear button gets a French label.
iface = gr.ChatInterface(
    chatbot,
    title="Dataltist Chatbot",
    description="Posez vos questions sur l'assurance",
    textbox=_input_box,
    theme=_ui_theme,
    retry_btn=None,
    undo_btn=None,
    submit_btn=_send_button,
    clear_btn="Effacer la conversation",
)

if __name__ == "__main__":
    # Start the web server (launch(share=True) would also create a public link).
    iface.launch()