Update rag_pipeline.py
Browse files- rag_pipeline.py +2 -2
rag_pipeline.py
CHANGED
@@ -15,7 +15,7 @@ class RAGPipeline:
|
|
15 |
self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
|
16 |
|
17 |
self.chunks = []
|
18 |
-
self.embeddings =
|
19 |
|
20 |
print("[RAG] تم تحميل النماذج بنجاح.")
|
21 |
|
@@ -24,7 +24,7 @@ class RAGPipeline:
|
|
24 |
self.embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
|
25 |
|
26 |
def retrieve_passages(self, question, top_k=5):
|
27 |
-
if
|
28 |
return []
|
29 |
question_embedding = self.embedder.encode([question], convert_to_numpy=True)
|
30 |
similarities = np.dot(self.embeddings, question_embedding.T).squeeze()
|
|
|
15 |
self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
|
16 |
|
17 |
self.chunks = []
|
18 |
+
self.embeddings = None
|
19 |
|
20 |
print("[RAG] تم تحميل النماذج بنجاح.")
|
21 |
|
|
|
24 |
self.embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
|
25 |
|
26 |
def retrieve_passages(self, question, top_k=5):
|
27 |
+
if self.embeddings is None or len(self.chunks) == 0:
|
28 |
return []
|
29 |
question_embedding = self.embedder.encode([question], convert_to_numpy=True)
|
30 |
similarities = np.dot(self.embeddings, question_embedding.T).squeeze()
|