ramy2018 committed on
Commit
d854811
·
verified ·
1 Parent(s): 47cd07d

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +26 -42
rag_pipeline.py CHANGED
@@ -14,64 +14,48 @@ class RAGPipeline:
14
  self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
15
  self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
16
 
17
- self.index = None
18
  self.chunks = []
19
- self.chunk_embeddings = []
20
- self.summaries = []
21
 
22
  print("[RAG] تم تحميل النماذج بنجاح.")
23
 
 
 
 
 
 
 
 
 
 
 
 
 
24
def summarize_text(self, text):
    """Summarize *text* in Arabic with the seq2seq model.

    Returns the decoded, stripped summary string, or "" when tokenization
    or generation fails (the error is printed, not raised).
    """
    prompt = f"لخص النص التالي باللغة العربية:\n\n{text}"
    try:
        encoded = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        generated = self.model.generate(encoded["input_ids"], max_length=256)
        decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        return decoded.strip()
    except Exception as e:
        print(f"[RAG] خطأ في التلخيص: {e}")
        return ""
33
 
34
def build_index(self, chunks, logs=None):
    """Embed *chunks* and build the in-memory similarity index.

    Args:
        chunks: list of text passages to index.
        logs: optional list; a status message is appended to it when given.

    Side effects: sets self.chunks, self.chunk_embeddings, self.index and
    clears self.summaries (they belong to the previous index, if any).
    """
    self.chunks = chunks
    self.chunk_embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
    # encode(..., convert_to_numpy=True) already yields an ndarray;
    # np.asarray avoids the needless full copy that np.array made.
    self.index = np.asarray(self.chunk_embeddings)
    self.summaries = []  # invalidate summaries computed for a previous index
    if logs is not None:
        logs.append(f"[RAG] تم بناء الفهرس لـ {len(self.chunk_embeddings)} مقطع.")
41
-
42
def summarize_all_chunks(self, max_chunks=20):
    """Pre-compute summaries for up to *max_chunks* of the indexed chunks.

    Progress is reported on stdout; results are stored in self.summaries
    in chunk order.
    """
    self.summaries = []
    total = min(max_chunks, len(self.chunks))
    print(f"[RAG] تلخيص {total} من {len(self.chunks)} مقطع...")
    selected = self.chunks[:total]
    for i, chunk in enumerate(selected):
        print(f"[RAG] تلخيص المقطع {i+1}/{total}")
        self.summaries.append(self.summarize_text(chunk))
50
-
51
def answer(self, question, top_k=5):
    """Answer *question* from the most similar indexed chunks.

    Args:
        question: the user's question text.
        top_k: number of most-similar chunks to retrieve as context
            (was hard-coded to 5; default preserves the old behavior).

    Returns:
        (answer, sources, combined_summary) — the generated answer (""
        on failure), the retrieved chunks, and the context string fed to
        the model (pre-computed summaries when available, else the raw
        chunks themselves).
    """
    question_embedding = self.embedder.encode([question], convert_to_numpy=True)
    # ravel (not squeeze) keeps a 1-D vector even when only one chunk is
    # indexed; squeeze() would collapse to a 0-d scalar and break argsort.
    similarities = np.dot(self.index, question_embedding.T).ravel()
    top_idx = similarities.argsort()[-top_k:][::-1]

    sources = [self.chunks[i] for i in top_idx]

    # Prefer pre-computed summaries (summarize_all_chunks may have covered
    # only a prefix of the chunks, hence the bounds check).
    relevant_summaries = [
        self.summaries[i]
        for i in top_idx
        if i < len(self.summaries) and self.summaries[i].strip()
    ]
    combined_summary = " ".join(relevant_summaries).strip()

    # Fall back to the raw chunks when no usable summaries exist.
    if not combined_summary:
        combined_summary = " ".join(sources)

    qa_prompt = f"أجب عن السؤال التالي بناء على النص:\n\n{combined_summary}\n\nالسؤال: {question}\nالإجابة:"
    try:
        inputs = self.tokenizer(qa_prompt, return_tensors="pt", truncation=True, max_length=512)
        output_ids = self.model.generate(inputs["input_ids"], max_length=200)
        answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"[RAG] خطأ في توليد الإجابة: {e}")
        answer = ""

    return answer, sources, combined_summary
 
14
  self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
15
  self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
16
 
 
17
  self.chunks = []
18
+ self.embeddings = []
 
19
 
20
  print("[RAG] تم تحميل النماذج بنجاح.")
21
 
22
def build_index(self, chunks):
    """Store *chunks* and compute their dense embeddings for retrieval."""
    vectors = self.embedder.encode(chunks, convert_to_numpy=True)
    self.chunks = chunks
    self.embeddings = vectors
25
+
26
def retrieve_passages(self, question, top_k=5):
    """Return up to *top_k* indexed chunks most similar to *question*.

    Args:
        question: the query text to embed and match.
        top_k: maximum number of passages to return, best first.

    Returns:
        List of chunk strings (empty when nothing has been indexed).
    """
    # len() works for both the initial empty list and the ndarray stored
    # by build_index; `not self.embeddings` raised ValueError on a
    # multi-row numpy array (ambiguous truth value).
    if len(self.chunks) == 0 or len(self.embeddings) == 0:
        return []
    question_embedding = self.embedder.encode([question], convert_to_numpy=True)
    # ravel keeps a 1-D vector even with a single indexed chunk, where
    # squeeze() collapsed to a 0-d scalar and broke argsort.
    similarities = np.dot(self.embeddings, question_embedding.T).ravel()
    # argsort is ascending: take the last top_k and reverse for best-first.
    top_indices = similarities.argsort()[-top_k:][::-1]
    return [self.chunks[i] for i in top_indices]
33
+
34
def summarize_text(self, text):
    """Summarize *text* with the seq2seq model; return "" on any failure."""
    # Reconstructed as a single-line f-string with an explicit \n: the
    # literal was split across physical lines, which is invalid syntax.
    prompt = f"لخص النص التالي:\n{text}"
    try:
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        summary_ids = self.model.generate(inputs["input_ids"], max_length=256)
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
    except Exception as e:
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit); report instead of failing silently.
        print(f"[RAG] خطأ في التلخيص: {e}")
        return ""
43
 
44
def generate_answer_from_passages(self, question, passages):
    """Generate an answer to *question* grounded in *passages*.

    The passages are first condensed with summarize_text, then the summary
    is used as the context of a QA prompt for the seq2seq model.

    Args:
        question: the user's question text.
        passages: list of retrieved chunk strings.

    Returns:
        (answer, summary) — answer is "" when generation fails; summary is
        the condensed context (possibly "" if summarization failed).
    """
    context = " ".join(passages)
    summary = self.summarize_text(context)

    # Reconstructed as adjacent literals with explicit \n: the original
    # f-string was split across physical lines, which is invalid syntax.
    prompt = (
        "أجب عن السؤال التالي بناء على النص:\n\n"
        f"{summary}\n\n"
        f"السؤال: {question}\n"
        "الإجابة:"
    )
    try:
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        output_ids = self.model.generate(inputs["input_ids"], max_length=200)
        answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    except Exception as e:
        # Narrowed from a bare `except:`; report rather than hide the error.
        print(f"[RAG] خطأ في توليد الإجابة: {e}")
        answer = ""

    return answer, summary