Update rag_pipeline.py
Browse files- rag_pipeline.py +27 -19
rag_pipeline.py
CHANGED
@@ -16,40 +16,48 @@ class RAGPipeline:
|
|
16 |
|
17 |
self.index = None
|
18 |
self.chunks = []
|
|
|
19 |
self.chunk_embeddings = []
|
20 |
|
21 |
print("[RAG] تم تحميل النماذج بنجاح.")
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def build_index(self, chunks, logs=None):
|
24 |
self.chunks = chunks
|
|
|
25 |
self.chunk_embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
|
26 |
-
if logs is not None:
|
27 |
-
logs.append(f"[RAG] تم بناء الفهرس بـ {self.chunk_embeddings.shape[0]} مقطع.")
|
28 |
self.index = np.array(self.chunk_embeddings)
|
|
|
|
|
29 |
|
30 |
def answer(self, question):
|
31 |
-
# Step 1: استرجاع المقاطع الأكثر صلة
|
32 |
question_embedding = self.embedder.encode([question], convert_to_numpy=True)
|
33 |
similarities = np.dot(self.index, question_embedding.T).squeeze()
|
34 |
top_idx = similarities.argsort()[-5:][::-1]
|
|
|
|
|
35 |
sources = [self.chunks[i] for i in top_idx]
|
36 |
-
context = " ".join(sources)
|
37 |
|
38 |
-
|
|
|
|
|
|
|
|
|
39 |
try:
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
|
44 |
except Exception as e:
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
# Step 3: توليد الإجابة من الملخص أو من النص الأصلي
|
49 |
-
qa_context = summary if summary else context
|
50 |
-
qa_prompt = f"أجب عن السؤال التالي بناء على النص:\n\n{qa_context}\n\nالسؤال: {question}\nالإجابة:"
|
51 |
-
qa_inputs = self.tokenizer(qa_prompt, return_tensors="pt", truncation=True, max_length=512)
|
52 |
-
answer_ids = self.model.generate(qa_inputs["input_ids"], max_length=200)
|
53 |
-
answer = self.tokenizer.decode(answer_ids[0], skip_special_tokens=True).strip()
|
54 |
|
55 |
-
return answer, sources,
|
|
|
16 |
|
17 |
self.index = None
|
18 |
self.chunks = []
|
19 |
+
self.summaries = []
|
20 |
self.chunk_embeddings = []
|
21 |
|
22 |
print("[RAG] تم تحميل النماذج بنجاح.")
|
23 |
|
24 |
+
def summarize_text(self, text):
    """Return a model-generated Arabic summary of *text*.

    The text is wrapped in an Arabic "summarize the following text"
    instruction, tokenized (truncated to 512 tokens) and passed to
    ``self.model.generate`` with an output limit of 256 tokens.

    Args:
        text: The passage to summarize.

    Returns:
        The decoded summary with surrounding whitespace stripped, or an
        empty string when *text* is empty/blank/``None`` or when
        tokenization or generation fails.
    """
    # Nothing to summarize: do not spend a generation call on a prompt that
    # contains only the instruction (the model would invent a summary).
    if text is None or not str(text).strip():
        return ""

    prompt = f"لخص النص التالي باللغة العربية:\n\n{text}"
    try:
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        summary_ids = self.model.generate(inputs["input_ids"], max_length=256)
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
    except Exception as e:
        # Best-effort by design: a failed summary must not abort indexing,
        # callers treat "" as "no summary available".
        print(f"[RAG] خطأ في التلخيص: {e}")
        return ""
|
33 |
+
|
34 |
def build_index(self, chunks, logs=None):
    """Summarize and embed *chunks* and store them as the search index.

    Args:
        chunks: Iterable of text passages. It is materialized into a new
            list once, so generators work and later mutation of the
            caller's list cannot desynchronize ``self.chunks`` from the
            stored embeddings.
        logs: Optional list; when given, a status line with the number of
            indexed passages is appended to it.

    Side effects:
        Sets ``self.chunks``, ``self.summaries`` (one entry per chunk, via
        ``self.summarize_text``), ``self.chunk_embeddings`` and
        ``self.index``. With no chunks the embeddings are reset to ``[]``
        and the index to ``None`` without calling the embedder.
    """
    # Iterated twice below (summaries + embeddings), so a one-shot iterable
    # must be captured first.
    self.chunks = list(chunks)

    if self.chunks:
        self.summaries = [self.summarize_text(chunk) for chunk in self.chunks]
        self.chunk_embeddings = self.embedder.encode(self.chunks, convert_to_numpy=True)
        self.index = np.array(self.chunk_embeddings)
    else:
        # Empty corpus: leave the pipeline in its "no index" state instead
        # of asking the embedder to encode an empty batch.
        self.summaries = []
        self.chunk_embeddings = []
        self.index = None

    if logs is not None:
        logs.append(f"[RAG] تم بناء الفهرس بـ {len(self.chunk_embeddings)} مقطع.")
|
41 |
|
42 |
def answer(self, question, top_k=5):
    """Answer *question* from the summaries of the most similar chunks.

    The question is embedded, scored against every row of ``self.index``
    with a dot product, and the ``top_k`` best-scoring chunks are selected
    (best first). Their non-blank summaries are joined into the context;
    if none is available the raw chunk texts are used instead. The context
    and question are placed in an Arabic QA prompt for ``self.model``.

    Args:
        question: The user question.
        top_k: Maximum number of chunks to retrieve (default 5, values
            below 1 are treated as 1).

    Returns:
        A tuple ``(answer_text, sources, context)`` where ``sources`` is
        the list of retrieved chunk texts and ``context`` is the text the
        model was prompted with. ``answer_text`` is ``""`` if generation
        fails; all three are empty when no index has been built.
    """
    index = self.index
    if index is None or len(index) == 0 or not self.chunks:
        # No corpus to search: keep the 3-tuple contract rather than
        # failing inside np.dot with an opaque error.
        print("[RAG] لم يتم بناء الفهرس بعد.")
        return "", [], ""

    question_embedding = self.embedder.encode([question], convert_to_numpy=True)
    # reshape(-1) instead of squeeze(): a single-chunk index must still
    # yield a 1-D score vector.
    similarities = np.dot(index, question_embedding.T).reshape(-1)
    limit = max(1, int(top_k))
    top_idx = similarities.argsort()[::-1][:limit]

    relevant_summaries = [self.summaries[i] for i in top_idx if self.summaries[i].strip()]
    sources = [self.chunks[i] for i in top_idx]

    combined_summary = " ".join(relevant_summaries).strip()
    if not combined_summary:
        # No usable summaries for the hits: fall back to the raw passages.
        combined_summary = " ".join(sources)

    qa_prompt = f"أجب عن السؤال التالي بناء على النص:\n\n{combined_summary}\n\nالسؤال: {question}\nالإجابة:"
    try:
        inputs = self.tokenizer(qa_prompt, return_tensors="pt", truncation=True, max_length=512)
        output_ids = self.model.generate(inputs["input_ids"], max_length=200)
        answer_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    except Exception as e:
        # Best-effort: still hand back the retrieved sources and context.
        print(f"[RAG] خطأ في توليد الإجابة: {e}")
        answer_text = ""

    return answer_text, sources, combined_summary
|