from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, models
import numpy as np
import torch


class RAGPipeline:
    """Small retrieval-augmented pipeline for Arabic text.

    Text chunks are embedded with a SentenceTransformer built from
    ``asafaya/bert-base-arabic`` plus a pooling layer.  Retrieval ranks the
    chunks by the dot product between their embeddings and the question
    embedding.  Summaries and answers are both generated with the
    ``csebuetnlp/mT5_multilingual_XLSum`` seq2seq model.

    Attributes are set in ``__init__``:
        embedder:   SentenceTransformer used for chunks and questions.
        tokenizer:  tokenizer of the seq2seq model.
        model:      seq2seq model used for summarizing and answering.
        chunks:     list of indexed text chunks (empty until ``build_index``).
        embeddings: chunk embeddings, or ``None`` while nothing is indexed.
    """

    def __init__(self):
        """Load the embedding and seq2seq models and start with an empty index."""
        print("[RAG] تحميل النماذج...")
        word_embedding_model = models.Transformer('asafaya/bert-base-arabic')
        pooling_model = models.Pooling(
            word_embedding_model.get_word_embedding_dimension()
        )
        self.embedder = SentenceTransformer(
            modules=[word_embedding_model, pooling_model]
        )

        # Model dedicated to Arabic summarization.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "csebuetnlp/mT5_multilingual_XLSum"
        )
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            "csebuetnlp/mT5_multilingual_XLSum"
        )

        self.chunks = []
        self.embeddings = None
        print("[RAG] تم تحميل النماذج بنجاح.")

    def build_index(self, chunks):
        """Embed *chunks* and store them as the retrieval index.

        Args:
            chunks: iterable of text chunks.  A single string is treated as
                one chunk.

        The chunks are copied into a new list so that later mutation of the
        caller's sequence cannot desynchronise ``self.chunks`` from
        ``self.embeddings``.  An empty input clears the index without calling
        the embedder.
        """
        if isinstance(chunks, str):
            chunks = [chunks]
        self.chunks = list(chunks)
        if not self.chunks:
            self.embeddings = None
            return
        self.embeddings = self.embedder.encode(self.chunks, convert_to_numpy=True)

    def retrieve_passages(self, question, top_k=5):
        """Return up to *top_k* indexed chunks, best match first.

        Args:
            question: query text.
            top_k: maximum number of chunks to return.

        Returns:
            A list of chunks ordered by descending score; an empty list when
            nothing is indexed or ``top_k`` is not positive.
        """
        # A non-positive top_k must yield nothing: ``[-0:]`` would otherwise
        # select the whole array and return every chunk.
        if self.embeddings is None or len(self.chunks) == 0 or top_k <= 0:
            return []

        question_embedding = self.embedder.encode([question], convert_to_numpy=True)
        # Raw (unnormalised) dot-product score per chunk.  reshape(-1) keeps
        # the result 1-D even when the index holds a single chunk, where
        # squeeze() would collapse it to a 0-d array.
        scores = np.dot(self.embeddings, question_embedding.T).reshape(-1)
        best_first = scores.argsort()[-top_k:][::-1]
        return [self.chunks[i] for i in best_first]

    def _generate(self, prompt, max_length):
        """Tokenize *prompt* (truncated to 512 tokens), generate, and decode."""
        inputs = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=512
        )
        output_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            # Forward the tokenizer's mask instead of letting generate() infer it.
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
        )
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

    def summarize_text(self, text):
        """Summarize *text* with the seq2seq model.

        Returns:
            The summary string, or ``""`` when *text* is blank or generation
            raises (the error is printed, not propagated).
        """
        print("[RAG][INPUT TO SUMMARIZE]:", text)
        # Nothing to summarize: skip the model call instead of generating
        # text from the bare prompt prefix.
        if isinstance(text, str) and not text.strip():
            return ""

        prompt = f"summarize: {text}"
        try:
            summary = self._generate(prompt, max_length=128)
            print(f"[RAG][DEBUG] الملخص الناتج:\n{summary}")
            return summary
        except Exception as e:
            # Deliberate best-effort: report and fall back to an empty summary.
            print(f"[RAG][ERROR] أثناء التلخيص: {e}")
            return ""

    def generate_answer_from_passages(self, question, passages_text):
        """Answer *question* from a summary of *passages_text*.

        The passages are first summarized; the summary and the question are
        then placed in an Arabic prompt and passed to the same seq2seq model.

        Returns:
            A tuple ``(answer, summary)``.  ``answer`` is ``""`` when
            generation raises (the error is printed, not propagated).
        """
        summary = self.summarize_text(passages_text)
        prompt = f"أجب عن السؤال التالي بناء على النص:\n\n{summary}\n\nالسؤال: {question}\nالإجابة:"
        try:
            answer = self._generate(prompt, max_length=200)
        except Exception as e:
            print(f"[RAG][ERROR] أثناء توليد الإجابة: {e}")
            answer = ""
        return answer, summary