ramy2018 committed on
Commit
97f073d
·
verified ·
1 Parent(s): 79b95ab

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +27 -19
rag_pipeline.py CHANGED
@@ -16,40 +16,48 @@ class RAGPipeline:
16
 
17
  self.index = None
18
  self.chunks = []
 
19
  self.chunk_embeddings = []
20
 
21
  print("[RAG] تم تحميل النماذج بنجاح.")
22
 
 
 
 
 
 
 
 
 
 
 
23
  def build_index(self, chunks, logs=None):
24
  self.chunks = chunks
 
25
  self.chunk_embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
26
- if logs is not None:
27
- logs.append(f"[RAG] تم بناء الفهرس بـ {self.chunk_embeddings.shape[0]} مقطع.")
28
  self.index = np.array(self.chunk_embeddings)
 
 
29
 
30
  def answer(self, question):
31
- # Step 1: استرجاع المقاطع الأكثر صلة
32
  question_embedding = self.embedder.encode([question], convert_to_numpy=True)
33
  similarities = np.dot(self.index, question_embedding.T).squeeze()
34
  top_idx = similarities.argsort()[-5:][::-1]
 
 
35
  sources = [self.chunks[i] for i in top_idx]
36
- context = " ".join(sources)
37
 
38
- # Step 2: تلخيص النص المسترجع
 
 
 
 
39
  try:
40
- summary_prompt = f"لخص النص التالي باللغة العربية:\n\n{context}"
41
- inputs = self.tokenizer(summary_prompt, return_tensors="pt", truncation=True, max_length=512)
42
- summary_ids = self.model.generate(inputs["input_ids"], max_length=256)
43
- summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
44
  except Exception as e:
45
- summary = ""
46
- print(f"[RAG] خطأ في التلخيص: {e}")
47
-
48
- # Step 3: توليد الإجابة من الملخص أو من النص الأصلي
49
- qa_context = summary if summary else context
50
- qa_prompt = f"أجب عن السؤال التالي بناء على النص:\n\n{qa_context}\n\nالسؤال: {question}\nالإجابة:"
51
- qa_inputs = self.tokenizer(qa_prompt, return_tensors="pt", truncation=True, max_length=512)
52
- answer_ids = self.model.generate(qa_inputs["input_ids"], max_length=200)
53
- answer = self.tokenizer.decode(answer_ids[0], skip_special_tokens=True).strip()
54
 
55
- return answer, sources, summary
 
16
 
17
  self.index = None
18
  self.chunks = []
19
+ self.summaries = []
20
  self.chunk_embeddings = []
21
 
22
  print("[RAG] تم تحميل النماذج بنجاح.")
23
 
24
+ def summarize_text(self, text):
25
+ prompt = f"لخص النص التالي باللغة العربية:\n\n{text}"
26
+ try:
27
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
28
+ summary_ids = self.model.generate(inputs["input_ids"], max_length=256)
29
+ return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
30
+ except Exception as e:
31
+ print(f"[RAG] خطأ في التلخيص: {e}")
32
+ return ""
33
+
34
  def build_index(self, chunks, logs=None):
35
  self.chunks = chunks
36
+ self.summaries = [self.summarize_text(chunk) for chunk in chunks]
37
  self.chunk_embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
 
 
38
  self.index = np.array(self.chunk_embeddings)
39
+ if logs is not None:
40
+ logs.append(f"[RAG] تم بناء الفهرس بـ {len(self.chunk_embeddings)} مقطع.")
41
 
42
  def answer(self, question):
 
43
  question_embedding = self.embedder.encode([question], convert_to_numpy=True)
44
  similarities = np.dot(self.index, question_embedding.T).squeeze()
45
  top_idx = similarities.argsort()[-5:][::-1]
46
+
47
+ relevant_summaries = [self.summaries[i] for i in top_idx if self.summaries[i].strip()]
48
  sources = [self.chunks[i] for i in top_idx]
 
49
 
50
+ combined_summary = " ".join(relevant_summaries).strip()
51
+ if not combined_summary:
52
+ combined_summary = " ".join(sources)
53
+
54
+ qa_prompt = f"أجب عن السؤال التالي بناء على النص:\n\n{combined_summary}\n\nالسؤال: {question}\nالإجابة:"
55
  try:
56
+ inputs = self.tokenizer(qa_prompt, return_tensors="pt", truncation=True, max_length=512)
57
+ output_ids = self.model.generate(inputs["input_ids"], max_length=200)
58
+ answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
 
59
  except Exception as e:
60
+ print(f"[RAG] خطأ في توليد الإجابة: {e}")
61
+ answer = ""
 
 
 
 
 
 
 
62
 
63
+ return answer, sources, combined_summary