from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, models
import numpy as np


class RAGPipeline:
    def __init__(self):
        print("[RAG] Loading models...")
        # Arabic BERT encoder with mean pooling for sentence embeddings
        word_embedding_model = models.Transformer('asafaya/bert-base-arabic')
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        self.embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
        # Multilingual summarization model with Arabic coverage
        self.tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
        self.model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
        self.chunks = []
        self.embeddings = None
        print("[RAG] Models loaded successfully.")

    def build_index(self, chunks):
        self.chunks = chunks
        # L2-normalize so the dot product in retrieve_passages is cosine similarity
        self.embeddings = self.embedder.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)

    def retrieve_passages(self, question, top_k=5):
        if self.embeddings is None or len(self.chunks) == 0:
            return []
        question_embedding = self.embedder.encode([question], convert_to_numpy=True, normalize_embeddings=True)
        # Cosine similarity against every chunk, then keep the top_k highest-scoring ones
        similarities = np.dot(self.embeddings, question_embedding.T).squeeze()
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [self.chunks[i] for i in top_indices]

    def summarize_text(self, text):
        print("[RAG][INPUT TO SUMMARIZE]:", text)
        # mT5_multilingual_XLSum was fine-tuned without a task prefix, so the raw
        # text is passed directly; a "summarize:" prefix is a T5 convention it does not use.
        try:
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            # **inputs also forwards the attention mask; beam search for more stable summaries
            summary_ids = self.model.generate(**inputs, max_length=128, num_beams=4)
            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
            print(f"[RAG][DEBUG] Generated summary:\n{summary}")
            return summary
        except Exception as e:
            print(f"[RAG][ERROR] During summarization: {e}")
            return ""

    def generate_answer_from_passages(self, question, passages_text):
        summary = self.summarize_text(passages_text)
        # Arabic prompt: "Answer the following question based on the text: {summary} Question: {question} Answer:"
        prompt = f"أجب عن السؤال التالي بناء على النص:\n\n{summary}\n\nالسؤال: {question}\nالإجابة:"
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            output_ids = self.model.generate(**inputs, max_length=200)
            answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        except Exception as e:
            print(f"[RAG][ERROR] During answer generation: {e}")
            answer = ""
        return answer, summary
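

# Minimal usage sketch (illustrative only): it shows the intended call order
# build_index -> retrieve_passages -> generate_answer_from_passages. The chunks
# and question below are hypothetical placeholders, not part of the pipeline.
if __name__ == "__main__":
    rag = RAGPipeline()
    # Hypothetical chunks; in a real run these would come from an upstream document splitter.
    chunks = [
        "الذكاء الاصطناعي هو فرع من علوم الحاسوب.",  # "AI is a branch of computer science."
        "تعلم الآلة يتيح للأنظمة التعلم من البيانات.",  # "Machine learning lets systems learn from data."
    ]
    rag.build_index(chunks)
    question = "ما هو الذكاء الاصطناعي؟"  # "What is artificial intelligence?"
    passages = rag.retrieve_passages(question, top_k=2)
    answer, summary = rag.generate_answer_from_passages(question, "\n".join(passages))
    print("Answer:", answer)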