Trabis commited on
Commit
56f8c62
1 Parent(s): 84ef393

Upload 2 files

Browse files
Files changed (2) hide show
  1. RAG_GRADIO.py +336 -0
  2. requirements.txt +90 -0
RAG_GRADIO.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from langchain_mistralai.chat_models import ChatMistralAI
3
+ from langchain.prompts import ChatPromptTemplate
4
+ import os
5
+ from pathlib import Path
6
+ from typing import List, Dict, Optional
7
+ import json
8
+ import faiss
9
+ import numpy as np
10
+ from langchain.schema import Document
11
+ from sentence_transformers import SentenceTransformer
12
+ import pickle
13
+ import re
14
+
15
+ class RAGLoader:
16
+ def __init__(self,
17
+ docs_folder: str = "./docs",
18
+ splits_folder: str = "./splits",
19
+ index_folder: str = "./index",
20
+ model_name: str = "intfloat/multilingual-e5-large"):
21
+ """
22
+ Initialise le RAG Loader
23
+
24
+ Args:
25
+ docs_folder: Dossier contenant les documents sources
26
+ splits_folder: Dossier où seront stockés les morceaux de texte
27
+ index_folder: Dossier où sera stocké l'index FAISS
28
+ model_name: Nom du modèle SentenceTransformer à utiliser
29
+ """
30
+ self.docs_folder = Path(docs_folder)
31
+ self.splits_folder = Path(splits_folder)
32
+ self.index_folder = Path(index_folder)
33
+ self.model_name = model_name
34
+
35
+ # Créer les dossiers s'ils n'existent pas
36
+ self.splits_folder.mkdir(parents=True, exist_ok=True)
37
+ self.index_folder.mkdir(parents=True, exist_ok=True)
38
+
39
+ # Chemins des fichiers
40
+ self.splits_path = self.splits_folder / "splits.json"
41
+ self.index_path = self.index_folder / "faiss.index"
42
+ self.documents_path = self.index_folder / "documents.pkl"
43
+
44
+ # Initialiser le modèle
45
+ self.model = None
46
+ self.index = None
47
+ self.indexed_documents = None
48
+
49
+ def load_and_split_texts(self) -> List[Document]:
50
+ """
51
+ Charge les textes du dossier docs, les découpe en morceaux et les sauvegarde
52
+ dans un fichier JSON unique.
53
+
54
+ Returns:
55
+ Liste de Documents contenant les morceaux de texte et leurs métadonnées
56
+ """
57
+ documents = []
58
+
59
+ # Vérifier d'abord si les splits existent déjà
60
+ if self._splits_exist():
61
+ print("Chargement des splits existants...")
62
+ return self._load_existing_splits()
63
+
64
+ print("Création de nouveaux splits...")
65
+ # Parcourir tous les fichiers du dossier docs
66
+ for file_path in self.docs_folder.glob("*.txt"):
67
+ with open(file_path, 'r', encoding='utf-8') as file:
68
+ text = file.read()
69
+
70
+ # Découper le texte en phrases
71
+ # chunks = [chunk.strip() for chunk in text.split('.') if chunk.strip()]
72
+ chunks = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
73
+
74
+ # Créer un Document pour chaque morceau
75
+ for i, chunk in enumerate(chunks):
76
+ doc = Document(
77
+ page_content=chunk,
78
+ metadata={
79
+ 'source': file_path.name,
80
+ 'chunk_id': i,
81
+ 'total_chunks': len(chunks)
82
+ }
83
+ )
84
+ documents.append(doc)
85
+
86
+ # Sauvegarder tous les splits dans un seul fichier JSON
87
+ self._save_splits(documents)
88
+
89
+ print(f"Nombre total de morceaux créés: {len(documents)}")
90
+ return documents
91
+
92
+ def _splits_exist(self) -> bool:
93
+ """Vérifie si le fichier de splits existe"""
94
+ return self.splits_path.exists()
95
+
96
+ def _save_splits(self, documents: List[Document]):
97
+ """Sauvegarde tous les documents découpés dans un seul fichier JSON"""
98
+ splits_data = {
99
+ 'splits': [
100
+ {
101
+ 'text': doc.page_content,
102
+ 'metadata': doc.metadata
103
+ }
104
+ for doc in documents
105
+ ]
106
+ }
107
+
108
+ with open(self.splits_path, 'w', encoding='utf-8') as f:
109
+ json.dump(splits_data, f, ensure_ascii=False, indent=2)
110
+
111
+ def _load_existing_splits(self) -> List[Document]:
112
+ """Charge les splits depuis le fichier JSON unique"""
113
+ with open(self.splits_path, 'r', encoding='utf-8') as f:
114
+ splits_data = json.load(f)
115
+
116
+ documents = [
117
+ Document(
118
+ page_content=split['text'],
119
+ metadata=split['metadata']
120
+ )
121
+ for split in splits_data['splits']
122
+ ]
123
+
124
+ print(f"Nombre de splits chargés: {len(documents)}")
125
+ return documents
126
+
127
+ def load_index(self) -> bool:
128
+ """
129
+ Charge l'index FAISS et les documents associés s'ils existent
130
+
131
+ Returns:
132
+ bool: True si l'index a été chargé, False sinon
133
+ """
134
+ if not self._index_exists():
135
+ print("Aucun index trouvé.")
136
+ return False
137
+
138
+ print("Chargement de l'index existant...")
139
+ try:
140
+ # Charger l'index FAISS
141
+ self.index = faiss.read_index(str(self.index_path))
142
+
143
+ # Charger les documents associés
144
+ with open(self.documents_path, 'rb') as f:
145
+ self.indexed_documents = pickle.load(f)
146
+
147
+ print(f"Index chargé avec {self.index.ntotal} vecteurs")
148
+ return True
149
+
150
+ except Exception as e:
151
+ print(f"Erreur lors du chargement de l'index: {e}")
152
+ return False
153
+
154
+ def create_index(self, documents: Optional[List[Document]] = None) -> bool:
155
+ """
156
+ Crée un nouvel index FAISS à partir des documents.
157
+ Si aucun document n'est fourni, charge les documents depuis le fichier JSON.
158
+
159
+ Args:
160
+ documents: Liste optionnelle de Documents à indexer
161
+
162
+ Returns:
163
+ bool: True si l'index a été créé avec succès, False sinon
164
+ """
165
+ try:
166
+ # Initialiser le modèle si nécessaire
167
+ if self.model is None:
168
+ print("Chargement du modèle...")
169
+ self.model = SentenceTransformer(self.model_name)
170
+
171
+ # Charger les documents si non fournis
172
+ if documents is None:
173
+ documents = self.load_and_split_texts()
174
+
175
+ if not documents:
176
+ print("Aucun document à indexer.")
177
+ return False
178
+
179
+ print("Création des embeddings...")
180
+ texts = [doc.page_content for doc in documents]
181
+ embeddings = self.model.encode(texts, show_progress_bar=True)
182
+
183
+ # Initialiser l'index FAISS
184
+ dimension = embeddings.shape[1]
185
+ self.index = faiss.IndexFlatL2(dimension)
186
+
187
+ # Ajouter les vecteurs à l'index
188
+ self.index.add(np.array(embeddings).astype('float32'))
189
+
190
+ # Sauvegarder l'index
191
+ print("Sauvegarde de l'index...")
192
+ faiss.write_index(self.index, str(self.index_path))
193
+
194
+ # Sauvegarder les documents associés
195
+ self.indexed_documents = documents
196
+ with open(self.documents_path, 'wb') as f:
197
+ pickle.dump(documents, f)
198
+
199
+ print(f"Index créé avec succès : {self.index.ntotal} vecteurs")
200
+ return True
201
+
202
+ except Exception as e:
203
+ print(f"Erreur lors de la création de l'index: {e}")
204
+ return False
205
+
206
+ def _index_exists(self) -> bool:
207
+ """Vérifie si l'index et les documents associés existent"""
208
+ return self.index_path.exists() and self.documents_path.exists()
209
+
210
+ def get_retriever(self, k: int = 5):
211
+ """
212
+ Crée un retriever pour l'utilisation avec LangChain
213
+
214
+ Args:
215
+ k: Nombre de documents similaires à retourner
216
+
217
+ Returns:
218
+ Callable: Fonction de recherche compatible avec LangChain
219
+ """
220
+ if self.index is None:
221
+ if not self.load_index():
222
+ if not self.create_index():
223
+ raise ValueError("Impossible de charger ou créer l'index")
224
+
225
+ if self.model is None:
226
+ self.model = SentenceTransformer(self.model_name)
227
+
228
+ def retriever_function(query: str) -> List[Document]:
229
+ # Créer l'embedding de la requête
230
+ query_embedding = self.model.encode([query])[0]
231
+
232
+ # Rechercher les documents similaires
233
+ distances, indices = self.index.search(
234
+ np.array([query_embedding]).astype('float32'),
235
+ k
236
+ )
237
+
238
+ # Retourner les documents trouvés
239
+ results = []
240
+ for idx in indices[0]:
241
+ if idx != -1: # FAISS retourne -1 pour les résultats invalides
242
+ results.append(self.indexed_documents[idx])
243
+
244
+ return results
245
+
246
+ return retriever_function
247
+
248
+ # Initialize the RAG system
249
+ llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key="QK0ZZpSxQbCEVgOLtI6FARQVmBYc6WGP")
250
+ rag_loader = RAGLoader()
251
+ retriever = rag_loader.get_retriever(k=5)
252
+
253
+ prompt_template = ChatPromptTemplate.from_messages([
254
+ ("system", """أنت مساعد مفيد يجيب على الأسئلة باللغة العربية باستخدام المعلومات المقدمة.
255
+ استخدم المعلومات التالية للإجابة على السؤال:
256
+
257
+ {context}
258
+
259
+ إذا لم تكن المعلومات كافية للإجابة على السؤال بشكل كامل، قم بتوضيح ذلك.
260
+ أجب بشكل موجز ودقيق."""),
261
+ ("human", "{question}")
262
+ ])
263
+
264
+ def process_question(question: str) -> tuple[str, str]:
265
+ """
266
+ Process a question and return both the answer and the relevant context
267
+ """
268
+ relevant_docs = retriever(question)
269
+ context = "\n".join([doc.page_content for doc in relevant_docs])
270
+
271
+ prompt = prompt_template.format_messages(
272
+ context=context,
273
+ question=question
274
+ )
275
+
276
+ response = llm(prompt)
277
+ return response.content, context
278
+
279
+ def gradio_interface(question: str) -> tuple[str, str]:
280
+ """
281
+ Gradio interface function that returns both answer and context as a tuple.
282
+ """
283
+ # Replace with your actual function to process the question
284
+ return process_question(question)
285
+
286
+ # Custom CSS for right-aligned and RTL text
287
+ custom_css = """
288
+ #question-box textarea, #answer-box textarea, #context-box textarea {
289
+ text-align: right !important;
290
+ direction: rtl !important;
291
+ }
292
+ """
293
+
294
+ # Test question
295
+ question = "هل يجوز لرجل السلطة اقتناء عقار داخل مجال عمله"
296
+ answer, context = process_question(question) # Ensure `process_question` is defined
297
+
298
+ # Print results for testing
299
+ print("الإجابة:", answer)
300
+ print("\nالسياق المستخدم:", context)
301
+
302
+ # Define the Gradio interface with custom CSS
303
+ with gr.Blocks(css=custom_css) as iface:
304
+ with gr.Column():
305
+ input_text = gr.Textbox(
306
+ label="السؤال",
307
+ placeholder="اكتب سؤالك هنا...",
308
+ lines=2,
309
+ elem_id="question-box"
310
+ )
311
+
312
+ answer_box = gr.Textbox(
313
+ label="الإجابة",
314
+ lines=4,
315
+ elem_id="answer-box"
316
+ )
317
+
318
+ context_box = gr.Textbox(
319
+ label="السياق المستخدم",
320
+ lines=8,
321
+ elem_id="context-box"
322
+ )
323
+
324
+ submit_btn = gr.Button("إرسال")
325
+
326
+ # Link submit button to processing function
327
+ submit_btn.click(
328
+ fn=gradio_interface,
329
+ inputs=input_text,
330
+ outputs=[answer_box, context_box]
331
+ )
332
+
333
+
334
+ # Launch the interface
335
+ if __name__ == "__main__":
336
+ iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.4.3
2
+ aiohttp==3.10.10
3
+ aiosignal==1.3.1
4
+ altair==5.4.1
5
+ annotated-types==0.7.0
6
+ anyio==4.6.2.post1
7
+ attrs==24.2.0
8
+ blinker==1.8.2
9
+ cachetools==5.5.0
10
+ certifi==2024.7.4
11
+ charset-normalizer==3.3.2
12
+ click==8.1.7
13
+ colorama==0.4.6
14
+ distro==1.9.0
15
+ einops==0.8.0
16
+ faiss-cpu==1.9.0
17
+ filelock==3.16.1
18
+ frozenlist==1.4.1
19
+ fsspec==2024.9.0
20
+ gitdb==4.0.11
21
+ GitPython==3.1.43
22
+ greenlet==3.1.1
23
+ h11==0.14.0
24
+ httpcore==1.0.6
25
+ httpx==0.27.2
26
+ httpx-sse==0.4.0
27
+ huggingface-hub==0.26.0
28
+ idna==3.7
29
+ Jinja2==3.1.4
30
+ jiter==0.6.1
31
+ jsonpatch==1.33
32
+ jsonpointer==3.0.0
33
+ jsonschema==4.23.0
34
+ jsonschema-specifications==2024.10.1
35
+ langchain==0.3.4
36
+ langchain-core==0.3.12
37
+ langchain-mistralai==0.2.0
38
+ langchain-openai==0.2.3
39
+ langchain-text-splitters==0.3.0
40
+ langsmith==0.1.136
41
+ markdown-it-py==3.0.0
42
+ MarkupSafe==3.0.1
43
+ mdurl==0.1.2
44
+ mpmath==1.3.0
45
+ multidict==6.1.0
46
+ narwhals==1.9.4
47
+ networkx==3.4.2
48
+ numpy==1.26.4
49
+ openai==1.52.0
50
+ orjson==3.10.6
51
+ packaging==24.1
52
+ pandas==2.2.3
53
+ pillow==10.4.0
54
+ propcache==0.2.0
55
+ protobuf==5.28.2
56
+ pyarrow==17.0.0
57
+ pydantic==2.8.2
58
+ pydantic_core==2.20.1
59
+ pydeck==0.9.1
60
+ Pygments==2.18.0
61
+ python-dateutil==2.9.0.post0
62
+ pytz==2024.2
63
+ PyYAML==6.0.1
64
+ referencing==0.35.1
65
+ regex==2024.9.11
66
+ requests==2.32.3
67
+ requests-toolbelt==1.0.0
68
+ rich==13.9.2
69
+ rpds-py==0.20.0
70
+ safetensors==0.4.5
71
+ six==1.16.0
72
+ smmap==5.0.1
73
+ sniffio==1.3.1
74
+ SQLAlchemy==2.0.36
75
+ streamlit==1.39.0
76
+ streamlit_arabic_support_wrapper==1.1
77
+ sympy==1.13.1
78
+ tenacity==8.5.0
79
+ tiktoken==0.8.0
80
+ tokenizers==0.20.1
81
+ toml==0.10.2
82
+ torch==2.5.0
83
+ tornado==6.4.1
84
+ tqdm==4.66.5
85
+ transformers==4.45.2
86
+ typing_extensions==4.12.2
87
+ tzdata==2024.2
88
+ urllib3==2.2.2
89
+ watchdog==5.0.3
90
+ yarl==1.15.5