Trabis commited on
Commit
9f620cb
1 Parent(s): 3b157eb

Update RAG_GRADIO.py

Browse files
Files changed (1) hide show
  1. RAG_GRADIO.py +328 -335
RAG_GRADIO.py CHANGED
@@ -1,336 +1,329 @@
1
- import gradio as gr
2
- from langchain_mistralai.chat_models import ChatMistralAI
3
- from langchain.prompts import ChatPromptTemplate
4
- import os
5
- from pathlib import Path
6
- from typing import List, Dict, Optional
7
- import json
8
- import faiss
9
- import numpy as np
10
- from langchain.schema import Document
11
- from sentence_transformers import SentenceTransformer
12
- import pickle
13
- import re
14
-
15
- class RAGLoader:
16
- def __init__(self,
17
- docs_folder: str = "./docs",
18
- splits_folder: str = "./splits",
19
- index_folder: str = "./index",
20
- model_name: str = "intfloat/multilingual-e5-large"):
21
- """
22
- Initialise le RAG Loader
23
-
24
- Args:
25
- docs_folder: Dossier contenant les documents sources
26
- splits_folder: Dossier où seront stockés les morceaux de texte
27
- index_folder: Dossier où sera stocké l'index FAISS
28
- model_name: Nom du modèle SentenceTransformer à utiliser
29
- """
30
- self.docs_folder = Path(docs_folder)
31
- self.splits_folder = Path(splits_folder)
32
- self.index_folder = Path(index_folder)
33
- self.model_name = model_name
34
-
35
- # Créer les dossiers s'ils n'existent pas
36
- self.splits_folder.mkdir(parents=True, exist_ok=True)
37
- self.index_folder.mkdir(parents=True, exist_ok=True)
38
-
39
- # Chemins des fichiers
40
- self.splits_path = self.splits_folder / "splits.json"
41
- self.index_path = self.index_folder / "faiss.index"
42
- self.documents_path = self.index_folder / "documents.pkl"
43
-
44
- # Initialiser le modèle
45
- self.model = None
46
- self.index = None
47
- self.indexed_documents = None
48
-
49
- def load_and_split_texts(self) -> List[Document]:
50
- """
51
- Charge les textes du dossier docs, les découpe en morceaux et les sauvegarde
52
- dans un fichier JSON unique.
53
-
54
- Returns:
55
- Liste de Documents contenant les morceaux de texte et leurs métadonnées
56
- """
57
- documents = []
58
-
59
- # Vérifier d'abord si les splits existent déjà
60
- if self._splits_exist():
61
- print("Chargement des splits existants...")
62
- return self._load_existing_splits()
63
-
64
- print("Création de nouveaux splits...")
65
- # Parcourir tous les fichiers du dossier docs
66
- for file_path in self.docs_folder.glob("*.txt"):
67
- with open(file_path, 'r', encoding='utf-8') as file:
68
- text = file.read()
69
-
70
- # Découper le texte en phrases
71
- # chunks = [chunk.strip() for chunk in text.split('.') if chunk.strip()]
72
- chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
73
-
74
- # Créer un Document pour chaque morceau
75
- for i, chunk in enumerate(chunks):
76
- doc = Document(
77
- page_content=chunk,
78
- metadata={
79
- 'source': file_path.name,
80
- 'chunk_id': i,
81
- 'total_chunks': len(chunks)
82
- }
83
- )
84
- documents.append(doc)
85
-
86
- # Sauvegarder tous les splits dans un seul fichier JSON
87
- self._save_splits(documents)
88
-
89
- print(f"Nombre total de morceaux créés: {len(documents)}")
90
- return documents
91
-
92
- def _splits_exist(self) -> bool:
93
- """Vérifie si le fichier de splits existe"""
94
- return self.splits_path.exists()
95
-
96
- def _save_splits(self, documents: List[Document]):
97
- """Sauvegarde tous les documents découpés dans un seul fichier JSON"""
98
- splits_data = {
99
- 'splits': [
100
- {
101
- 'text': doc.page_content,
102
- 'metadata': doc.metadata
103
- }
104
- for doc in documents
105
- ]
106
- }
107
-
108
- with open(self.splits_path, 'w', encoding='utf-8') as f:
109
- json.dump(splits_data, f, ensure_ascii=False, indent=2)
110
-
111
- def _load_existing_splits(self) -> List[Document]:
112
- """Charge les splits depuis le fichier JSON unique"""
113
- with open(self.splits_path, 'r', encoding='utf-8') as f:
114
- splits_data = json.load(f)
115
-
116
- documents = [
117
- Document(
118
- page_content=split['text'],
119
- metadata=split['metadata']
120
- )
121
- for split in splits_data['splits']
122
- ]
123
-
124
- print(f"Nombre de splits chargés: {len(documents)}")
125
- return documents
126
-
127
- def load_index(self) -> bool:
128
- """
129
- Charge l'index FAISS et les documents associés s'ils existent
130
-
131
- Returns:
132
- bool: True si l'index a été chargé, False sinon
133
- """
134
- if not self._index_exists():
135
- print("Aucun index trouvé.")
136
- return False
137
-
138
- print("Chargement de l'index existant...")
139
- try:
140
- # Charger l'index FAISS
141
- self.index = faiss.read_index(str(self.index_path))
142
-
143
- # Charger les documents associés
144
- with open(self.documents_path, 'rb') as f:
145
- self.indexed_documents = pickle.load(f)
146
-
147
- print(f"Index chargé avec {self.index.ntotal} vecteurs")
148
- return True
149
-
150
- except Exception as e:
151
- print(f"Erreur lors du chargement de l'index: {e}")
152
- return False
153
-
154
- def create_index(self, documents: Optional[List[Document]] = None) -> bool:
155
- """
156
- Crée un nouvel index FAISS à partir des documents.
157
- Si aucun document n'est fourni, charge les documents depuis le fichier JSON.
158
-
159
- Args:
160
- documents: Liste optionnelle de Documents à indexer
161
-
162
- Returns:
163
- bool: True si l'index a été créé avec succès, False sinon
164
- """
165
- try:
166
- # Initialiser le modèle si nécessaire
167
- if self.model is None:
168
- print("Chargement du modèle...")
169
- self.model = SentenceTransformer(self.model_name)
170
-
171
- # Charger les documents si non fournis
172
- if documents is None:
173
- documents = self.load_and_split_texts()
174
-
175
- if not documents:
176
- print("Aucun document à indexer.")
177
- return False
178
-
179
- print("Création des embeddings...")
180
- texts = [doc.page_content for doc in documents]
181
- embeddings = self.model.encode(texts, show_progress_bar=True)
182
-
183
- # Initialiser l'index FAISS
184
- dimension = embeddings.shape[1]
185
- self.index = faiss.IndexFlatL2(dimension)
186
-
187
- # Ajouter les vecteurs à l'index
188
- self.index.add(np.array(embeddings).astype('float32'))
189
-
190
- # Sauvegarder l'index
191
- print("Sauvegarde de l'index...")
192
- faiss.write_index(self.index, str(self.index_path))
193
-
194
- # Sauvegarder les documents associés
195
- self.indexed_documents = documents
196
- with open(self.documents_path, 'wb') as f:
197
- pickle.dump(documents, f)
198
-
199
- print(f"Index créé avec succès : {self.index.ntotal} vecteurs")
200
- return True
201
-
202
- except Exception as e:
203
- print(f"Erreur lors de la création de l'index: {e}")
204
- return False
205
-
206
- def _index_exists(self) -> bool:
207
- """Vérifie si l'index et les documents associés existent"""
208
- return self.index_path.exists() and self.documents_path.exists()
209
-
210
- def get_retriever(self, k: int = 5):
211
- """
212
- Crée un retriever pour l'utilisation avec LangChain
213
-
214
- Args:
215
- k: Nombre de documents similaires à retourner
216
-
217
- Returns:
218
- Callable: Fonction de recherche compatible avec LangChain
219
- """
220
- if self.index is None:
221
- if not self.load_index():
222
- if not self.create_index():
223
- raise ValueError("Impossible de charger ou créer l'index")
224
-
225
- if self.model is None:
226
- self.model = SentenceTransformer(self.model_name)
227
-
228
- def retriever_function(query: str) -> List[Document]:
229
- # Créer l'embedding de la requête
230
- query_embedding = self.model.encode([query])[0]
231
-
232
- # Rechercher les documents similaires
233
- distances, indices = self.index.search(
234
- np.array([query_embedding]).astype('float32'),
235
- k
236
- )
237
-
238
- # Retourner les documents trouvés
239
- results = []
240
- for idx in indices[0]:
241
- if idx != -1: # FAISS retourne -1 pour les résultats invalides
242
- results.append(self.indexed_documents[idx])
243
-
244
- return results
245
-
246
- return retriever_function
247
-
248
- # Initialize the RAG system
249
- llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key="QK0ZZpSxQbCEVgOLtI6FARQVmBYc6WGP")
250
- rag_loader = RAGLoader()
251
- retriever = rag_loader.get_retriever(k=5)
252
-
253
- prompt_template = ChatPromptTemplate.from_messages([
254
- ("system", """أنت مساعد مفيد يجيب على الأسئلة باللغة العربية باستخدام المعلومات المقدمة.
255
- استخدم المعلومات التالية للإجابة على السؤال:
256
-
257
- {context}
258
-
259
- إذا لم تكن المعلومات كافية للإجابة على السؤال بشكل كامل، قم بتوضيح ذلك.
260
- أجب بشكل موجز ودقيق."""),
261
- ("human", "{question}")
262
- ])
263
-
264
- def process_question(question: str) -> tuple[str, str]:
265
- """
266
- Process a question and return both the answer and the relevant context
267
- """
268
- relevant_docs = retriever(question)
269
- context = "\n".join([doc.page_content for doc in relevant_docs])
270
-
271
- prompt = prompt_template.format_messages(
272
- context=context,
273
- question=question
274
- )
275
-
276
- response = llm(prompt)
277
- return response.content, context
278
-
279
- def gradio_interface(question: str) -> tuple[str, str]:
280
- """
281
- Gradio interface function that returns both answer and context as a tuple.
282
- """
283
- # Replace with your actual function to process the question
284
- return process_question(question)
285
-
286
- # Custom CSS for right-aligned and RTL text
287
- custom_css = """
288
- #question-box textarea, #answer-box textarea, #context-box textarea {
289
- text-align: right !important;
290
- direction: rtl !important;
291
- }
292
- """
293
-
294
- # Test question
295
- question = "هل يجوز لرجل السلطة اقتناء عقار داخل مجال عمله"
296
- answer, context = process_question(question) # Ensure `process_question` is defined
297
-
298
- # Print results for testing
299
- print("الإجابة:", answer)
300
- print("\nالسياق المستخدم:", context)
301
-
302
- # Define the Gradio interface with custom CSS
303
- with gr.Blocks(css=custom_css) as iface:
304
- with gr.Column():
305
- input_text = gr.Textbox(
306
- label="السؤال",
307
- placeholder="اكتب سؤالك هنا...",
308
- lines=2,
309
- elem_id="question-box"
310
- )
311
-
312
- answer_box = gr.Textbox(
313
- label="الإجابة",
314
- lines=4,
315
- elem_id="answer-box"
316
- )
317
-
318
- context_box = gr.Textbox(
319
- label="السياق المستخدم",
320
- lines=8,
321
- elem_id="context-box"
322
- )
323
-
324
- submit_btn = gr.Button("إرسال")
325
-
326
- # Link submit button to processing function
327
- submit_btn.click(
328
- fn=gradio_interface,
329
- inputs=input_text,
330
- outputs=[answer_box, context_box]
331
- )
332
-
333
-
334
- # Launch the interface
335
- if __name__ == "__main__":
336
  iface.launch(share=True)
 
1
+ import gradio as gr
2
+ from langchain_mistralai.chat_models import ChatMistralAI
3
+ from langchain.prompts import ChatPromptTemplate
4
+ import os
5
+ from pathlib import Path
6
+ from typing import List, Dict, Optional
7
+ import json
8
+ import faiss
9
+ import numpy as np
10
+ from langchain.schema import Document
11
+ from sentence_transformers import SentenceTransformer
12
+ import pickle
13
+ import re
14
+
15
+ class RAGLoader:
16
+ def __init__(self,
17
+ docs_folder: str = "./docs",
18
+ splits_folder: str = "./splits",
19
+ index_folder: str = "./index",
20
+ model_name: str = "intfloat/multilingual-e5-large"):
21
+ """
22
+ Initialise le RAG Loader
23
+
24
+ Args:
25
+ docs_folder: Dossier contenant les documents sources
26
+ splits_folder: Dossier où seront stockés les morceaux de texte
27
+ index_folder: Dossier où sera stocké l'index FAISS
28
+ model_name: Nom du modèle SentenceTransformer à utiliser
29
+ """
30
+ self.docs_folder = Path(docs_folder)
31
+ self.splits_folder = Path(splits_folder)
32
+ self.index_folder = Path(index_folder)
33
+ self.model_name = model_name
34
+
35
+ # Créer les dossiers s'ils n'existent pas
36
+ self.splits_folder.mkdir(parents=True, exist_ok=True)
37
+ self.index_folder.mkdir(parents=True, exist_ok=True)
38
+
39
+ # Chemins des fichiers
40
+ self.splits_path = self.splits_folder / "splits.json"
41
+ self.index_path = self.index_folder / "faiss.index"
42
+ self.documents_path = self.index_folder / "documents.pkl"
43
+
44
+ # Initialiser le modèle
45
+ self.model = None
46
+ self.index = None
47
+ self.indexed_documents = None
48
+
49
+ def load_and_split_texts(self) -> List[Document]:
50
+ """
51
+ Charge les textes du dossier docs, les découpe en morceaux et les sauvegarde
52
+ dans un fichier JSON unique.
53
+
54
+ Returns:
55
+ Liste de Documents contenant les morceaux de texte et leurs métadonnées
56
+ """
57
+ documents = []
58
+
59
+ # Vérifier d'abord si les splits existent déjà
60
+ if self._splits_exist():
61
+ print("Chargement des splits existants...")
62
+ return self._load_existing_splits()
63
+
64
+ print("Création de nouveaux splits...")
65
+ # Parcourir tous les fichiers du dossier docs
66
+ for file_path in self.docs_folder.glob("*.txt"):
67
+ with open(file_path, 'r', encoding='utf-8') as file:
68
+ text = file.read()
69
+
70
+ # Découper le texte en phrases
71
+ # chunks = [chunk.strip() for chunk in text.split('.') if chunk.strip()]
72
+ chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
73
+
74
+ # Créer un Document pour chaque morceau
75
+ for i, chunk in enumerate(chunks):
76
+ doc = Document(
77
+ page_content=chunk,
78
+ metadata={
79
+ 'source': file_path.name,
80
+ 'chunk_id': i,
81
+ 'total_chunks': len(chunks)
82
+ }
83
+ )
84
+ documents.append(doc)
85
+
86
+ # Sauvegarder tous les splits dans un seul fichier JSON
87
+ self._save_splits(documents)
88
+
89
+ print(f"Nombre total de morceaux créés: {len(documents)}")
90
+ return documents
91
+
92
+ def _splits_exist(self) -> bool:
93
+ """Vérifie si le fichier de splits existe"""
94
+ return self.splits_path.exists()
95
+
96
+ def _save_splits(self, documents: List[Document]):
97
+ """Sauvegarde tous les documents découpés dans un seul fichier JSON"""
98
+ splits_data = {
99
+ 'splits': [
100
+ {
101
+ 'text': doc.page_content,
102
+ 'metadata': doc.metadata
103
+ }
104
+ for doc in documents
105
+ ]
106
+ }
107
+
108
+ with open(self.splits_path, 'w', encoding='utf-8') as f:
109
+ json.dump(splits_data, f, ensure_ascii=False, indent=2)
110
+
111
+ def _load_existing_splits(self) -> List[Document]:
112
+ """Charge les splits depuis le fichier JSON unique"""
113
+ with open(self.splits_path, 'r', encoding='utf-8') as f:
114
+ splits_data = json.load(f)
115
+
116
+ documents = [
117
+ Document(
118
+ page_content=split['text'],
119
+ metadata=split['metadata']
120
+ )
121
+ for split in splits_data['splits']
122
+ ]
123
+
124
+ print(f"Nombre de splits chargés: {len(documents)}")
125
+ return documents
126
+
127
+ def load_index(self) -> bool:
128
+ """
129
+ Charge l'index FAISS et les documents associés s'ils existent
130
+
131
+ Returns:
132
+ bool: True si l'index a été chargé, False sinon
133
+ """
134
+ if not self._index_exists():
135
+ print("Aucun index trouvé.")
136
+ return False
137
+
138
+ print("Chargement de l'index existant...")
139
+ try:
140
+ # Charger l'index FAISS
141
+ self.index = faiss.read_index(str(self.index_path))
142
+
143
+ # Charger les documents associés
144
+ with open(self.documents_path, 'rb') as f:
145
+ self.indexed_documents = pickle.load(f)
146
+
147
+ print(f"Index chargé avec {self.index.ntotal} vecteurs")
148
+ return True
149
+
150
+ except Exception as e:
151
+ print(f"Erreur lors du chargement de l'index: {e}")
152
+ return False
153
+
154
+ def create_index(self, documents: Optional[List[Document]] = None) -> bool:
155
+ """
156
+ Crée un nouvel index FAISS à partir des documents.
157
+ Si aucun document n'est fourni, charge les documents depuis le fichier JSON.
158
+
159
+ Args:
160
+ documents: Liste optionnelle de Documents à indexer
161
+
162
+ Returns:
163
+ bool: True si l'index a été créé avec succès, False sinon
164
+ """
165
+ try:
166
+ # Initialiser le modèle si nécessaire
167
+ if self.model is None:
168
+ print("Chargement du modèle...")
169
+ self.model = SentenceTransformer(self.model_name)
170
+
171
+ # Charger les documents si non fournis
172
+ if documents is None:
173
+ documents = self.load_and_split_texts()
174
+
175
+ if not documents:
176
+ print("Aucun document à indexer.")
177
+ return False
178
+
179
+ print("Création des embeddings...")
180
+ texts = [doc.page_content for doc in documents]
181
+ embeddings = self.model.encode(texts, show_progress_bar=True)
182
+
183
+ # Initialiser l'index FAISS
184
+ dimension = embeddings.shape[1]
185
+ self.index = faiss.IndexFlatL2(dimension)
186
+
187
+ # Ajouter les vecteurs à l'index
188
+ self.index.add(np.array(embeddings).astype('float32'))
189
+
190
+ # Sauvegarder l'index
191
+ print("Sauvegarde de l'index...")
192
+ faiss.write_index(self.index, str(self.index_path))
193
+
194
+ # Sauvegarder les documents associés
195
+ self.indexed_documents = documents
196
+ with open(self.documents_path, 'wb') as f:
197
+ pickle.dump(documents, f)
198
+
199
+ print(f"Index créé avec succès : {self.index.ntotal} vecteurs")
200
+ return True
201
+
202
+ except Exception as e:
203
+ print(f"Erreur lors de la création de l'index: {e}")
204
+ return False
205
+
206
+ def _index_exists(self) -> bool:
207
+ """Vérifie si l'index et les documents associés existent"""
208
+ return self.index_path.exists() and self.documents_path.exists()
209
+
210
+ def get_retriever(self, k: int = 5):
211
+ """
212
+ Crée un retriever pour l'utilisation avec LangChain
213
+
214
+ Args:
215
+ k: Nombre de documents similaires à retourner
216
+
217
+ Returns:
218
+ Callable: Fonction de recherche compatible avec LangChain
219
+ """
220
+ if self.index is None:
221
+ if not self.load_index():
222
+ if not self.create_index():
223
+ raise ValueError("Impossible de charger ou créer l'index")
224
+
225
+ if self.model is None:
226
+ self.model = SentenceTransformer(self.model_name)
227
+
228
+ def retriever_function(query: str) -> List[Document]:
229
+ # Créer l'embedding de la requête
230
+ query_embedding = self.model.encode([query])[0]
231
+
232
+ # Rechercher les documents similaires
233
+ distances, indices = self.index.search(
234
+ np.array([query_embedding]).astype('float32'),
235
+ k
236
+ )
237
+
238
+ # Retourner les documents trouvés
239
+ results = []
240
+ for idx in indices[0]:
241
+ if idx != -1: # FAISS retourne -1 pour les résultats invalides
242
+ results.append(self.indexed_documents[idx])
243
+
244
+ return results
245
+
246
+ return retriever_function
247
+
248
+ # Initialize the RAG system
249
+ llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key="QK0ZZpSxQbCEVgOLtI6FARQVmBYc6WGP")
250
+ rag_loader = RAGLoader()
251
+ retriever = rag_loader.get_retriever(k=5)
252
+
253
+ prompt_template = ChatPromptTemplate.from_messages([
254
+ ("system", """أنت مساعد مفيد يجيب على الأسئلة باللغة العربية باستخدام المعلومات المقدمة.
255
+ استخدم المعلومات التالية للإجابة على السؤال:
256
+
257
+ {context}
258
+
259
+ إذا لم تكن المعلومات كافية للإجابة على السؤال بشكل كامل، قم بتوضيح ذلك.
260
+ أجب بشكل موجز ودقيق."""),
261
+ ("human", "{question}")
262
+ ])
263
+
264
+ def process_question(question: str) -> tuple[str, str]:
265
+ """
266
+ Process a question and return both the answer and the relevant context
267
+ """
268
+ relevant_docs = retriever(question)
269
+ context = "\n".join([doc.page_content for doc in relevant_docs])
270
+
271
+ prompt = prompt_template.format_messages(
272
+ context=context,
273
+ question=question
274
+ )
275
+
276
+ response = llm(prompt)
277
+ return response.content, context
278
+
279
+ def gradio_interface(question: str) -> tuple[str, str]:
280
+ """
281
+ Gradio interface function that returns both answer and context as a tuple
282
+ """
283
+ return process_question(question)
284
+
285
+ # Custom CSS for right-aligned text in textboxes
286
+ custom_css = """
287
+ .rtl-text {
288
+ text-align: right !important;
289
+ direction: rtl !important;
290
+ }
291
+ .rtl-text textarea {
292
+ text-align: right !important;
293
+ direction: rtl !important;
294
+ }
295
+ """
296
+
297
+ # Define the Gradio interface
298
+ with gr.Blocks(css=custom_css) as iface:
299
+ with gr.Column():
300
+ input_text = gr.Textbox(
301
+ label="السؤال",
302
+ placeholder="اكتب سؤالك هنا...",
303
+ lines=2,
304
+ elem_classes="rtl-text"
305
+ )
306
+
307
+ answer_box = gr.Textbox(
308
+ label="الإجابة",
309
+ lines=4,
310
+ elem_classes="rtl-text"
311
+ )
312
+
313
+ context_box = gr.Textbox(
314
+ label="السياق المستخدم",
315
+ lines=8,
316
+ elem_classes="rtl-text"
317
+ )
318
+
319
+ submit_btn = gr.Button("إرسال")
320
+
321
+ submit_btn.click(
322
+ fn=gradio_interface,
323
+ inputs=input_text,
324
+ outputs=[answer_box, context_box]
325
+ )
326
+
327
+ # Launch the interface
328
+ if __name__ == "__main__":
 
 
 
 
 
 
 
329
  iface.launch(share=True)