ramy2018 committed on
Commit
c04d662
·
verified ·
1 Parent(s): 4932f0a

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +5 -3
rag_pipeline.py CHANGED
@@ -35,9 +35,10 @@ class RAGPipeline:
35
  prompt = f"لخص النص التالي:\n{text}"
36
  try:
37
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
38
- summary_ids = self.model.generate(inputs["input_ids"], max_length=256)
39
  return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
40
- except:
 
41
  return ""
42
 
43
  def generate_answer_from_passages(self, question, passages_text):
@@ -48,7 +49,8 @@ class RAGPipeline:
48
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
49
  output_ids = self.model.generate(inputs["input_ids"], max_length=200)
50
  answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
51
- except:
 
52
  answer = ""
53
 
54
  return answer, summary
 
35
  prompt = f"لخص النص التالي:\n{text}"
36
  try:
37
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
38
+ summary_ids = self.model.generate(inputs["input_ids"], max_length=128)
39
  return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
40
+ except Exception as e:
41
+ print(f"[RAG][ERROR] أثناء التلخيص: {e}")
42
  return ""
43
 
44
  def generate_answer_from_passages(self, question, passages_text):
 
49
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
50
  output_ids = self.model.generate(inputs["input_ids"], max_length=200)
51
  answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
52
+ except Exception as e:
53
+ print(f"[RAG][ERROR] أثناء توليد الإجابة: {e}")
54
  answer = ""
55
 
56
  return answer, summary