abdelom commited on
Commit
9ba08dd
·
verified ·
1 Parent(s): 98894ba

Update pages/1_Chatbot_FR.py

Browse files
Files changed (1) hide show
  1. pages/1_Chatbot_FR.py +170 -44
pages/1_Chatbot_FR.py CHANGED
@@ -3,22 +3,17 @@ import pandas as pd
3
  import os
4
  from pathlib import Path
5
  import base64
 
 
 
6
 
7
- # LangChain & Hugging Face
8
- from langchain.embeddings import HuggingFaceEmbeddings
9
- from langchain.vectorstores import Chroma
10
- from langchain.schema import Document
11
- from langchain.prompts import PromptTemplate
12
- from langchain.llms import HuggingFaceHub
13
- from langchain.chains import LLMChain
14
-
15
  import pysqlite3
16
- import sys
17
  sys.modules["sqlite3"] = pysqlite3
18
 
19
- #####################
20
- # 1. HELPER FUNCTIONS
21
- #####################
22
 
23
  def get_base64_of_bin_file(bin_file_path: str) -> str:
24
  file_bytes = Path(bin_file_path).read_bytes()
@@ -92,28 +87,29 @@ def create_contextual_fr(df, category, strat_id=0):
92
  def load_excel_and_create_vectorstore_fr(excel_path: str, persist_dir: str = "./chroma_db_fr"):
93
  """
94
  Charge les données depuis plusieurs feuilles Excel (version FR),
95
- construit & stocke un Chroma VectorStore.
96
  """
97
- # 1. Charger les feuilles Excel
98
  qna_tree_fr0 = pd.read_excel(excel_path, sheet_name="Prépayé (FR)", skiprows=1).iloc[:, :5]
99
  qna_tree_fr1 = pd.read_excel(excel_path, sheet_name="Postpayé (FR)", skiprows=1).iloc[:, :5]
100
  qna_tree_fr2 = pd.read_excel(excel_path, sheet_name="Wifi (FR)", skiprows=1).iloc[:, :5]
101
 
102
- # 2. Construire le contexte
103
  context_fr0 = create_contextual_fr(qna_tree_fr0, "Prépayé", strat_id = 0)
104
  context_fr1 = create_contextual_fr(qna_tree_fr1, "Postpayé", strat_id = len(context_fr0))
105
  context_fr2 = create_contextual_fr(qna_tree_fr2, "Wifi", strat_id = len(context_fr0) + len(context_fr1))
106
 
107
- # 3. Concaténer les DataFrame
108
  context_fr = pd.concat([context_fr0, context_fr1, context_fr2], axis=0)
109
 
110
- # 4. Créer une colonne "context"
111
  context_fr["context"] = context_fr.apply(
112
  lambda row: f"{row['question']} > {row['answer']}",
113
  axis=1
114
  )
115
 
116
- # 5. Convertir chaque ligne en Document
 
117
  documents_fr = [
118
  Document(
119
  page_content=row["context"],
@@ -122,7 +118,9 @@ def load_excel_and_create_vectorstore_fr(excel_path: str, persist_dir: str = "./
122
  for _, row in context_fr.iterrows()
123
  ]
124
 
125
- # 6. Créer & persister le vecteur
 
 
126
  embedding_model_fr = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
127
  vectorstore_fr = Chroma.from_documents(documents_fr, embedding_model_fr, persist_directory=persist_dir)
128
  vectorstore_fr.persist()
@@ -133,6 +131,8 @@ def load_existing_vectorstore_fr(persist_dir: str = "./chroma_db_fr"):
133
  """
134
  Charge un VectorStore Chroma déjà stocké (version FR).
135
  """
 
 
136
  embedding_model_fr = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
137
  vectorstore_fr = Chroma(
138
  persist_directory=persist_dir,
@@ -150,10 +150,105 @@ def retrieve_context_fr(retriever_fr, query, top_k=5):
150
  context_fr_list.append(result.page_content)
151
  return context_fr_list
152
 
 
 
 
153
 
154
- #########################
155
- # 2. PROMPT & LLM FR #
156
- #########################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  prompt_template_fr = PromptTemplate(
159
  input_variables=["context", "query"],
@@ -162,7 +257,7 @@ prompt_template_fr = PromptTemplate(
162
  Vous êtes un assistant client professionnel, expérimenté et bienveillant pour l'opérateur téléphonique INWI.
163
  Vous excellez dans la gestion des clients, en répondant à leurs problèmes et questions.
164
  Fournir un service client et des conseils en se basant sur les contextes fournis :
165
- - Répondre aux salutations de manière courtoise et amicale, par exemple : "Bonjour! Je suis l'assistant IA d'INWI'. Comment puis-je vous aider aujourd'hui ?"
166
  - Identifier le besoin du client et demander des clarifications si nécessaire, tout en s'appuyant uniquement sur le contexte.
167
  - Si la question n'est pas liée au contexte d'INWI, veuillez informer poliment que vous ne pouvez pas répondre à des questions hors contexte INWI.
168
  - Si la réponse ne figure pas dans le contexte, vous pouvez dire "Je n'ai pas assez d'information" et proposer d'appeler le service client au 120.
@@ -189,17 +284,13 @@ pour les particuliers et les entreprises. Il se distingue par son engagement à
189
  accessibles, tout en contribuant au développement numérique du pays.
190
  Les clients sont notre priorité, et notre but est de résoudre leurs problèmes.
191
  Votre rôle est de fournir un service client professionnel et efficace sans inventer d'informations.
192
-
193
  [CONTEXTE]
194
  {context}
195
-
196
  [QUESTION DU CLIENT]
197
  {query}
198
-
199
  [RÉPONSE]"""
200
  )
201
  )
202
- # Configuration du LLM HuggingFace (FR)
203
 
204
  llm_fr = HuggingFaceHub(
205
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
@@ -210,18 +301,16 @@ llm_fr = HuggingFaceHub(
210
  }
211
  )
212
 
213
- # Chaîne FR
214
  llm_chain_fr = LLMChain(llm=llm_fr, prompt=prompt_template_fr)
215
 
216
-
217
- #########################
218
- # 3. STREAMLIT MAIN APP #
219
- #########################
220
 
221
  def main():
222
  st.subheader("INWI IA Chatbot - Français")
223
 
224
- # Read local image and convert to Base64
225
  img_base64 = get_base64_of_bin_file("./img/logo inwi celeverlytics.png")
226
  css_logo = f"""
227
  <style>
@@ -238,10 +327,9 @@ def main():
238
  }}
239
  </style>
240
  """
241
-
242
  st.markdown(css_logo, unsafe_allow_html=True)
243
 
244
- # Charger ou créer le retriever
245
  if "retriever_fr" not in st.session_state:
246
  st.session_state["retriever_fr"] = None
247
 
@@ -274,7 +362,7 @@ def main():
274
  st.write("""Je suis là pour répondre à toutes vos questions concernant nos
275
  services, nos offres mobiles et Internet, ainsi que nos solutions adaptées à vos besoins (FR).""")
276
 
277
- # Zone de texte
278
  user_query_fr = st.chat_input("Posez votre question ici (FR)...")
279
 
280
  if user_query_fr:
@@ -282,23 +370,61 @@ def main():
282
  st.warning("Veuillez d'abord créer ou charger la Vector Store (FR).")
283
  return
284
 
285
- # Récupération du contexte
286
  context_fr_list = retrieve_context_fr(st.session_state["retriever_fr"], user_query_fr, top_k=5)
287
 
288
  if context_fr_list:
289
  with st.spinner("Génération de la réponse..."):
290
- response_fr = llm_chain_fr.run({"context": "\n".join(context_fr_list), "query": user_query_fr + "?"})
291
- # Séparer si jamais le prompt contient [RÉPONSE], sinon on affiche tout
292
- response_fr = response_fr.split("[RÉPONSE]")[-1]
 
 
 
 
 
293
  st.write("**Question :**")
294
  st.write(user_query_fr)
295
- st.write("**Réponse :**")
296
  st.write(response_fr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  else:
298
  st.write("Aucun contexte trouvé pour cette question. Essayez autre chose.")
299
 
300
-
301
  if __name__ == "__main__":
302
  main()
303
-
304
-
 
3
  import os
4
  from pathlib import Path
5
  import base64
6
+ import sys
7
+ import torch
8
+ from transformers import BertForSequenceClassification, BertTokenizer
9
 
10
+ # Force using pysqlite3 if needed
 
 
 
 
 
 
 
11
  import pysqlite3
 
12
  sys.modules["sqlite3"] = pysqlite3
13
 
14
+ ##############################
15
+ # 1. HELPER FUNCTIONS (for the main chatbot)
16
+ ##############################
17
 
18
  def get_base64_of_bin_file(bin_file_path: str) -> str:
19
  file_bytes = Path(bin_file_path).read_bytes()
 
87
  def load_excel_and_create_vectorstore_fr(excel_path: str, persist_dir: str = "./chroma_db_fr"):
88
  """
89
  Charge les données depuis plusieurs feuilles Excel (version FR),
90
+ construit & stocke un Chroma VectorStore pour le chatbot.
91
  """
92
+ # Charger les feuilles Excel
93
  qna_tree_fr0 = pd.read_excel(excel_path, sheet_name="Prépayé (FR)", skiprows=1).iloc[:, :5]
94
  qna_tree_fr1 = pd.read_excel(excel_path, sheet_name="Postpayé (FR)", skiprows=1).iloc[:, :5]
95
  qna_tree_fr2 = pd.read_excel(excel_path, sheet_name="Wifi (FR)", skiprows=1).iloc[:, :5]
96
 
97
+ # Construire le contexte
98
  context_fr0 = create_contextual_fr(qna_tree_fr0, "Prépayé", strat_id = 0)
99
  context_fr1 = create_contextual_fr(qna_tree_fr1, "Postpayé", strat_id = len(context_fr0))
100
  context_fr2 = create_contextual_fr(qna_tree_fr2, "Wifi", strat_id = len(context_fr0) + len(context_fr1))
101
 
102
+ # Concaténer les DataFrame
103
  context_fr = pd.concat([context_fr0, context_fr1, context_fr2], axis=0)
104
 
105
+ # Créer une colonne "context"
106
  context_fr["context"] = context_fr.apply(
107
  lambda row: f"{row['question']} > {row['answer']}",
108
  axis=1
109
  )
110
 
111
+ # Convertir chaque ligne en Document (pour Chroma)
112
+ from langchain.schema import Document
113
  documents_fr = [
114
  Document(
115
  page_content=row["context"],
 
118
  for _, row in context_fr.iterrows()
119
  ]
120
 
121
+ # Créer & persister le vector store
122
+ from langchain.embeddings import HuggingFaceEmbeddings
123
+ from langchain.vectorstores import Chroma
124
  embedding_model_fr = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
125
  vectorstore_fr = Chroma.from_documents(documents_fr, embedding_model_fr, persist_directory=persist_dir)
126
  vectorstore_fr.persist()
 
131
  """
132
  Charge un VectorStore Chroma déjà stocké (version FR).
133
  """
134
+ from langchain.embeddings import HuggingFaceEmbeddings
135
+ from langchain.vectorstores import Chroma
136
  embedding_model_fr = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
137
  vectorstore_fr = Chroma(
138
  persist_directory=persist_dir,
 
150
  context_fr_list.append(result.page_content)
151
  return context_fr_list
152
 
153
+ ##############################
154
+ # 2. CLASSIFICATION MODEL SETUP
155
+ ##############################
156
 
157
+ # Specify the path where the classification model and tokenizer are saved.
158
+ MODEL_PATH = "saved_bert_model_v1"
159
+
160
+ # Load the tokenizer and model for sequence classification.
161
+ tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
162
+ model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
163
+ model.eval() # Set model to evaluation mode
164
+
165
+ # Use GPU if available.
166
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
167
+ model.to(device)
168
+
169
+ def predict_class(text, max_length=500):
170
+ """
171
+ Predicts the class (as string) for a given input text.
172
+ """
173
+ inputs = tokenizer(
174
+ text,
175
+ add_special_tokens=True,
176
+ max_length=max_length,
177
+ padding='max_length',
178
+ truncation=True,
179
+ return_tensors="pt"
180
+ )
181
+ # Move inputs to the device.
182
+ inputs = {k: v.to(device) for k, v in inputs.items()}
183
+
184
+ with torch.no_grad():
185
+ outputs = model(**inputs)
186
+
187
+ logits = outputs.logits
188
+ predicted_class_id = torch.argmax(logits, dim=1).item()
189
+ predicted_label = model.config.id2label[predicted_class_id]
190
+ return predicted_label
191
+
192
+ ##############################
193
+ # 3. CLASSIFICATION DATASET & VECTOR STORE
194
+ ##############################
195
+
196
+ @st.cache_data(show_spinner=False)
197
+ def load_classification_dataset():
198
+ """
199
+ Loads the classification Q&A dataset from the Excel file and returns a DataFrame.
200
+ """
201
+ df = pd.read_excel("Classification dataset - Q&A.xlsx", sheet_name="Fr")
202
+ return df
203
+
204
+ @st.cache_resource(show_spinner=False)
205
+ def load_classification_vectorstore(persist_dir: str = "./chroma_db_class_fr"):
206
+ """
207
+ Builds (and persists) a Chroma vector store from the classification Q&A dataset.
208
+ Each document contains the answer (Réponse) with metadata including the class ("Classe").
209
+ """
210
+ df = load_classification_dataset()
211
+ # Create documents using the "Réponse" as content and include metadata.
212
+ from langchain.schema import Document
213
+ documents = []
214
+ for _, row in df.iterrows():
215
+ documents.append(
216
+ Document(
217
+ page_content=row["Réponse"],
218
+ metadata={
219
+ "id": row["ID"],
220
+ "Classe": row["Classe"],
221
+ "Question": row["Question"]
222
+ }
223
+ )
224
+ )
225
+ from langchain.embeddings import HuggingFaceEmbeddings
226
+ from langchain.vectorstores import Chroma
227
+ embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
228
+ vectorstore = Chroma.from_documents(documents, embedding_model, persist_directory=persist_dir)
229
+ vectorstore.persist()
230
+ return vectorstore
231
+
232
+ def load_existing_classification_vectorstore(persist_dir: str = "./chroma_db_class_fr"):
233
+ """
234
+ Loads an existing Chroma vector store for the classification dataset.
235
+ """
236
+ from langchain.embeddings import HuggingFaceEmbeddings
237
+ from langchain.vectorstores import Chroma
238
+ embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
239
+ vectorstore = Chroma(
240
+ persist_directory=persist_dir,
241
+ embedding_function=embedding_model
242
+ )
243
+ return vectorstore
244
+
245
+ ##############################
246
+ # 4. PROMPT & LLM FR SETUP
247
+ ##############################
248
+
249
+ from langchain.prompts import PromptTemplate
250
+ from langchain.llms import HuggingFaceHub
251
+ from langchain.chains import LLMChain
252
 
253
  prompt_template_fr = PromptTemplate(
254
  input_variables=["context", "query"],
 
257
  Vous êtes un assistant client professionnel, expérimenté et bienveillant pour l'opérateur téléphonique INWI.
258
  Vous excellez dans la gestion des clients, en répondant à leurs problèmes et questions.
259
  Fournir un service client et des conseils en se basant sur les contextes fournis :
260
+ - Répondre aux salutations de manière courtoise et amicale, par exemple : "Bonjour! Je suis l'assistant IA d'INWI. Comment puis-je vous aider aujourd'hui ?"
261
  - Identifier le besoin du client et demander des clarifications si nécessaire, tout en s'appuyant uniquement sur le contexte.
262
  - Si la question n'est pas liée au contexte d'INWI, veuillez informer poliment que vous ne pouvez pas répondre à des questions hors contexte INWI.
263
  - Si la réponse ne figure pas dans le contexte, vous pouvez dire "Je n'ai pas assez d'information" et proposer d'appeler le service client au 120.
 
284
  accessibles, tout en contribuant au développement numérique du pays.
285
  Les clients sont notre priorité, et notre but est de résoudre leurs problèmes.
286
  Votre rôle est de fournir un service client professionnel et efficace sans inventer d'informations.
 
287
  [CONTEXTE]
288
  {context}
 
289
  [QUESTION DU CLIENT]
290
  {query}
 
291
  [RÉPONSE]"""
292
  )
293
  )
 
294
 
295
  llm_fr = HuggingFaceHub(
296
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
 
301
  }
302
  )
303
 
 
304
  llm_chain_fr = LLMChain(llm=llm_fr, prompt=prompt_template_fr)
305
 
306
+ ##############################
307
+ # 5. STREAMLIT MAIN APP
308
+ ##############################
 
309
 
310
  def main():
311
  st.subheader("INWI IA Chatbot - Français")
312
 
313
+ # Sidebar: add logo image.
314
  img_base64 = get_base64_of_bin_file("./img/logo inwi celeverlytics.png")
315
  css_logo = f"""
316
  <style>
 
327
  }}
328
  </style>
329
  """
 
330
  st.markdown(css_logo, unsafe_allow_html=True)
331
 
332
+ # Load or create the retriever for the main chatbot context.
333
  if "retriever_fr" not in st.session_state:
334
  st.session_state["retriever_fr"] = None
335
 
 
362
  st.write("""Je suis là pour répondre à toutes vos questions concernant nos
363
  services, nos offres mobiles et Internet, ainsi que nos solutions adaptées à vos besoins (FR).""")
364
 
365
+ # Text input for user's question.
366
  user_query_fr = st.chat_input("Posez votre question ici (FR)...")
367
 
368
  if user_query_fr:
 
370
  st.warning("Veuillez d'abord créer ou charger la Vector Store (FR).")
371
  return
372
 
373
+ # Retrieve context from the main chatbot vector store.
374
  context_fr_list = retrieve_context_fr(st.session_state["retriever_fr"], user_query_fr, top_k=5)
375
 
376
  if context_fr_list:
377
  with st.spinner("Génération de la réponse..."):
378
+ # Run the LLM chain to generate a candidate answer.
379
+ response_fr = llm_chain_fr.run({
380
+ "context": "\n".join(context_fr_list),
381
+ "query": user_query_fr + "?"
382
+ })
383
+ # Remove any prompt markers.
384
+ response_fr = response_fr.split("[RÉPONSE]")[-1].strip()
385
+
386
  st.write("**Question :**")
387
  st.write(user_query_fr)
388
+ st.write("**Réponse générée par l'IA :**")
389
  st.write(response_fr)
390
+
391
+ # --- Classification step ---
392
+ with st.spinner("Classification de la réponse..."):
393
+ predicted_label = predict_class(response_fr)
394
+ st.write(f"**Classe prédite :** {predicted_label}")
395
+
396
+ # --- Retrieve final answer using the classification vector store ---
397
+ if predicted_label != "Autre":
398
+ # Build or load the classification vector store if not already in session_state.
399
+ if "class_retriever" not in st.session_state:
400
+ # Either create new or load existing
401
+ try:
402
+ # Attempt to load an existing vector store.
403
+ vectorstore_class = load_existing_classification_vectorstore("./chroma_db_class_fr")
404
+ except Exception:
405
+ # If not found, create it.
406
+ vectorstore_class = load_classification_vectorstore("./chroma_db_class_fr")
407
+ st.session_state["class_retriever"] = vectorstore_class.as_retriever(
408
+ search_type="mmr",
409
+ search_kwargs={"k": 1, "lambda_mult": 0.5}
410
+ )
411
+ # Retrieve the final answer with a metadata filter.
412
+ # (Assumes the underlying retriever supports a filter parameter.)
413
+ final_docs = st.session_state["class_retriever"].get_relevant_documents(
414
+ response_fr, filter={"Classe": predicted_label}
415
+ )
416
+ if final_docs:
417
+ final_answer = final_docs[0].page_content
418
+ else:
419
+ final_answer = response_fr # fallback if no document found
420
+ else:
421
+ final_answer = ("Je n'ai pas d'information précise à ce sujet. "
422
+ "Souhaitez-vous que je vous mette en contact avec un agent Inwi ?")
423
+
424
+ st.write("**Réponse finale :**")
425
+ st.write(final_answer)
426
  else:
427
  st.write("Aucun contexte trouvé pour cette question. Essayez autre chose.")
428
 
 
429
  if __name__ == "__main__":
430
  main()