Update rag_pipeline.py
rag_pipeline.py CHANGED (+26 -42)
@@ -14,64 +14,48 @@ class RAGPipeline:
         self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
         self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
 
-        self.index = None
         self.chunks = []
-        self.chunk_embeddings = []
-        self.summaries = []
+        self.embeddings = []
 
         print("[RAG] Models loaded successfully.")
 
+    def build_index(self, chunks):
+        self.chunks = chunks
+        self.embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
+
+    def retrieve_passages(self, question, top_k=5):
+        if len(self.chunks) == 0 or len(self.embeddings) == 0:
+            return []
+        question_embedding = self.embedder.encode([question], convert_to_numpy=True)
+        similarities = np.dot(self.embeddings, question_embedding.T).squeeze()
+        top_indices = similarities.argsort()[-top_k:][::-1]
+        return [self.chunks[i] for i in top_indices]
+
     def summarize_text(self, text):
-        prompt = f"Summarize the text …"
+        prompt = f"Summarize the following text:\n{text}"
         try:
             inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
             summary_ids = self.model.generate(inputs["input_ids"], max_length=256)
             return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
-        except Exception as e:
-            print(f"[RAG] Summarization error: {e}")
+        except Exception:
             return ""
 
-    def build_index(self, chunks, logs=None):
-        self.chunks = chunks
-        self.chunk_embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
-        self.index = np.array(self.chunk_embeddings)
-        self.summaries = []
-        if logs is not None:
-            logs.append(f"[RAG] Built the index for {len(self.chunk_embeddings)} chunks.")
-
-    def summarize_all_chunks(self, max_chunks=20):
-        self.summaries = []
-        total = min(max_chunks, len(self.chunks))
-        print(f"[RAG] Summarizing {total} of {len(self.chunks)} chunks...")
-        for i, chunk in enumerate(self.chunks[:total]):
-            print(f"[RAG] Summarizing chunk {i+1}/{total}")
-            summary = self.summarize_text(chunk)
-            self.summaries.append(summary)
-
-    def answer(self, question):
-        question_embedding = self.embedder.encode([question], convert_to_numpy=True)
-        similarities = np.dot(self.index, question_embedding.T).squeeze()
-        top_idx = similarities.argsort()[-5:][::-1]
-        relevant_summaries = [
-            self.summaries[i]
-            for i in top_idx
-            if i < len(self.summaries) and self.summaries[i].strip()
-        ]
-        combined_summary = " ".join(relevant_summaries).strip()
-
-        …
-
-        qa_prompt = f"Answer the following question based on the text:\n\n{combined_summary}\n\nQuestion: {question}\nAnswer:"
+    def generate_answer_from_passages(self, question, passages):
+        context = " ".join(passages)
+        summary = self.summarize_text(context)
+
+        prompt = f"Answer the following question based on the text:\n\n{summary}\n\nQuestion: {question}\nAnswer:"
         try:
-            inputs = self.tokenizer(qa_prompt, return_tensors="pt", truncation=True, max_length=512)
+            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
             output_ids = self.model.generate(inputs["input_ids"], max_length=200)
             answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
-        except Exception as e:
-            print(f"[RAG] Error generating answer: {e}")
+        except Exception:
             answer = ""
 
-        return answer, combined_summary
+        return answer, summary
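For reference, the new retrieval path scores every chunk with a raw dot product against the question embedding and keeps the top_k highest-scoring indices from argsort. Below is a minimal standalone sketch of that step, assuming a sentence-transformers model such as all-MiniLM-L6-v2 (the hunk never shows which model self.embedder wraps, so the model name and the sample chunks here are placeholders):

import numpy as np
from sentence_transformers import SentenceTransformer

# Assumed embedder; the diff does not show how self.embedder is constructed.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

chunks = [
    "Paris is the capital of France.",
    "Photosynthesis converts light into chemical energy.",
    "The Eiffel Tower stands in Paris.",
]
embeddings = embedder.encode(chunks, convert_to_numpy=True)  # shape (n_chunks, dim)

question_embedding = embedder.encode(["Where is the Eiffel Tower?"], convert_to_numpy=True)
similarities = np.dot(embeddings, question_embedding.T).squeeze()  # one score per chunk

top_k = 2
top_indices = similarities.argsort()[-top_k:][::-1]  # highest scores first
print([chunks[i] for i in top_indices])

Note that a dot product equals cosine similarity only when the vectors have unit length; passing normalize_embeddings=True to encode would make the ranking independent of vector magnitude.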
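A hypothetical end-to-end call sequence for the reworked class, assuming the constructor still takes no arguments as the hunk suggests:

from rag_pipeline import RAGPipeline

# Build the index once, then retrieve and answer per question.
pipeline = RAGPipeline()
pipeline.build_index([
    "First document chunk ...",
    "Second document chunk ...",
])

question = "What does the document say about the first topic?"
passages = pipeline.retrieve_passages(question, top_k=3)
answer, summary = pipeline.generate_answer_from_passages(question, passages)
print(answer)

Since both prompts are truncated at 512 tokens, a long passage list is clipped before summarization, so top_k effectively bounds how much context reaches flan-t5.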