ramy2018 commited on
Commit
497c142
·
verified ·
1 Parent(s): 80609fb

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +2 -2
rag_pipeline.py CHANGED
@@ -15,7 +15,7 @@ class RAGPipeline:
15
  self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
16
 
17
  self.chunks = []
18
- self.embeddings = []
19
 
20
  print("[RAG] تم تحميل النماذج بنجاح.")
21
 
@@ -24,7 +24,7 @@ class RAGPipeline:
24
  self.embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
25
 
26
  def retrieve_passages(self, question, top_k=5):
27
- if not self.embeddings or not self.chunks:
28
  return []
29
  question_embedding = self.embedder.encode([question], convert_to_numpy=True)
30
  similarities = np.dot(self.embeddings, question_embedding.T).squeeze()
 
15
  self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
16
 
17
  self.chunks = []
18
+ self.embeddings = None
19
 
20
  print("[RAG] تم تحميل النماذج بنجاح.")
21
 
 
24
  self.embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
25
 
26
  def retrieve_passages(self, question, top_k=5):
27
+ if self.embeddings is None or len(self.chunks) == 0:
28
  return []
29
  question_embedding = self.embedder.encode([question], convert_to_numpy=True)
30
  similarities = np.dot(self.embeddings, question_embedding.T).squeeze()