from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, models
import numpy as np


class RAGPipeline:
    """Retrieval-augmented generation pipeline for Arabic text: an Arabic BERT
    bi-encoder retrieves relevant chunks, and mT5 (XLSum) condenses them and
    generates the answer."""

    def __init__(self):
        print("[RAG] Loading models...")

        # Sentence embedder: Arabic BERT with a mean-pooling head on top.
        word_embedding_model = models.Transformer('asafaya/bert-base-arabic')
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        self.embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model])

        # Multilingual summarization model, also reused below for answer generation.
        self.tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
        self.model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")

        self.chunks = []
        self.embeddings = None

        print("[RAG] Models loaded successfully.")

    def build_index(self, chunks):
        """Embed all chunks once. Unit-normalized vectors make the dot
        product in retrieve_passages equal to cosine similarity."""
        self.chunks = chunks
        self.embeddings = self.embedder.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)

    def retrieve_passages(self, question, top_k=5):
        """Return the top_k chunks most similar to the question."""
        if self.embeddings is None or len(self.chunks) == 0:
            return []
        question_embedding = self.embedder.encode([question], convert_to_numpy=True, normalize_embeddings=True)
        # Cosine similarity against all chunk embeddings; ravel() keeps a 1-D
        # score vector even when the index holds a single chunk.
        similarities = np.dot(self.embeddings, question_embedding.T).ravel()
        # Indices of the top_k highest scores, best match first.
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [self.chunks[i] for i in top_indices]

    def summarize_text(self, text):
        """Condense the retrieved passages with mT5 XLSum."""
        print("[RAG][INPUT TO SUMMARIZE]:", text)
        try:
            # mT5_multilingual_XLSum was fine-tuned without a T5-style
            # "summarize:" task prefix, so the text is passed directly.
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            # Beam-search settings recommended on the model card.
            summary_ids = self.model.generate(inputs["input_ids"], max_length=128, num_beams=4, no_repeat_ngram_size=2)
            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
            print(f"[RAG][DEBUG] Generated summary:\n{summary}")
            return summary
        except Exception as e:
            print(f"[RAG][ERROR] during summarization: {e}")
            return ""

    def generate_answer_from_passages(self, question, passages_text):
        """Summarize the retrieved passages, then prompt the model to answer
        the question from that summary. Returns (answer, summary)."""
        summary = self.summarize_text(passages_text)

        # Arabic prompt, kept as-is since the model is queried in Arabic:
        # "Answer the following question based on the text:
        #  <summary>  Question: <question>  Answer:"
        prompt = f"أجب عن السؤال التالي بناء على النص:\n\n{summary}\n\nالسؤال: {question}\nالإجابة:"
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            output_ids = self.model.generate(inputs["input_ids"], max_length=200)
            answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        except Exception as e:
            print(f"[RAG][ERROR] while generating the answer: {e}")
            answer = ""

        return answer, summary
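
# --- Minimal usage sketch (not part of the original module) ---
# Assumption: the chunk texts and the question below are illustrative
# placeholders; real chunks would come from your own document splitter.
if __name__ == "__main__":
    rag = RAGPipeline()
    rag.build_index([
        "الفقرة الأولى من الوثيقة.",   # "first paragraph of the document" (placeholder)
        "الفقرة الثانية من الوثيقة.",  # "second paragraph of the document" (placeholder)
    ])
    question = "ما موضوع الوثيقة؟"  # "What is the document about?" (placeholder)
    passages = rag.retrieve_passages(question, top_k=2)
    answer, summary = rag.generate_answer_from_passages(question, "\n".join(passages))
    print("Answer:", answer)
    print("Summary:", summary)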