Update rag_pipeline.py
Browse files- rag_pipeline.py +5 -3
rag_pipeline.py
CHANGED
@@ -35,9 +35,10 @@ class RAGPipeline:
|
|
35 |
prompt = f"لخص النص التالي:\n{text}"
|
36 |
try:
|
37 |
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
38 |
-
summary_ids = self.model.generate(inputs["input_ids"], max_length=
|
39 |
return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
|
40 |
-
except:
|
|
|
41 |
return ""
|
42 |
|
43 |
def generate_answer_from_passages(self, question, passages_text):
|
@@ -48,7 +49,8 @@ class RAGPipeline:
|
|
48 |
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
49 |
output_ids = self.model.generate(inputs["input_ids"], max_length=200)
|
50 |
answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
51 |
-
except:
|
|
|
52 |
answer = ""
|
53 |
|
54 |
return answer, summary
|
|
|
35 |
prompt = f"لخص النص التالي:\n{text}"
|
36 |
try:
|
37 |
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
38 |
+
summary_ids = self.model.generate(inputs["input_ids"], max_length=128)
|
39 |
return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
|
40 |
+
except Exception as e:
|
41 |
+
print(f"[RAG][ERROR] أثناء التلخيص: {e}")
|
42 |
return ""
|
43 |
|
44 |
def generate_answer_from_passages(self, question, passages_text):
|
|
|
49 |
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
50 |
output_ids = self.model.generate(inputs["input_ids"], max_length=200)
|
51 |
answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
52 |
+
except Exception as e:
|
53 |
+
print(f"[RAG][ERROR] أثناء توليد الإجابة: {e}")
|
54 |
answer = ""
|
55 |
|
56 |
return answer, summary
|