namberino committed
Commit ffd37d2 · 1 Parent(s): 2f5ff1a
Files changed (2)
  1. enhanced_rag_mcq.py +13 -106
  2. fastapi_app.py +0 -12
enhanced_rag_mcq.py CHANGED
@@ -657,43 +657,6 @@ class EnhancedRAGMCQGenerator:
             print(f"Raw response: {response[:500]}...")
             raise ValueError(f"Failed to parse LLM response: {e}")
 
-    def _batch_invoke(self, prompts: List[str]) -> List[str]:
-        if not prompts:
-            return []
-
-        # Try to use transformers pipeline (batch mode)
-        pl = getattr(self.llm, "pipeline", None)
-        if pl is not None:
-            try:
-                # Call the pipeline with a list. Transformers will return a list of generation outputs.
-                raw_outputs = pl(prompts)
-
-                responses = []
-                for out in raw_outputs:
-                    # The pipeline may return either a dict (single result) or a list of dicts (if return_full_text or num_return_sequences was set)
-                    if isinstance(out, list) and out:
-                        text = out[0].get("generated_text", "")
-                    elif isinstance(out, dict):
-                        text = out.get("generated_text", "")
-                    else:
-                        # fallback: coerce to string
-                        text = str(out)
-                    responses.append(text)
-
-                if len(responses) == len(prompts):
-                    return responses
-                else:
-                    print("⚠️ Batch pipeline returned unexpected shape - falling back")
-            except Exception as e:
-                # Batch mode failed. Fall back to sequential invocations.
-                print(f"⚠️ Batch invoke failed: {e}. Falling back to sequential.")
-
-        # Sequential invocation to preserve behavior
-        results = []
-        for p in prompts:
-            results.append(self.llm.invoke(p))
-        return results
-
     def generate_batch(self,
                        topics: List[str],
                        question_per_topic: int = 5,
@@ -707,9 +670,8 @@ class EnhancedRAGMCQGenerator:
         if question_types is None:
             question_types = [QuestionType.DEFINITION, QuestionType.APPLICATION]
 
+        mcqs = []
         total_questions = len(topics) * question_per_topic
-        prompt_metadatas = []  # stores tuples (topic, difficulty, question_type)
-        formatted_prompts = []
 
         print(f"🎯 Generating {total_questions} MCQs...")
 
@@ -717,76 +679,21 @@ class EnhancedRAGMCQGenerator:
             print(f"📝 Processing topic {i+1}/{len(topics)}: {topic}")
 
             for j in range(question_per_topic):
-                difficulty = difficulties[j % len(difficulties)]
-                question_type = question_types[j % len(question_types)]
-
-                query = topic
-
-                # retrieve context once per prompt
-                contexts = self.retriever.retrieve_diverse_contexts(query, k=self.config.get("k", 5)) if hasattr(self, "retriever") else []
-                context_text = "\n\n".join([d.page_content for d in contexts]) if contexts else topic
-
-                # Build prompt with PromptTemplateManager
-                prompt_template = self.template_manager.get_template(question_type)
-                prompt_input = {
-                    "context": context_text,
-                    "topic": topic,
-                    "difficulty": difficulty.value if hasattr(difficulty, 'value') else str(difficulty),
-                    "question_type": question_type.value if hasattr(question_type, 'value') else str(question_type)
-                }
-                formatted = prompt_template.format(**prompt_input)
-
-                # Length check and fallback
-                is_safe, token_count = self.check_prompt_length(formatted)
-                if not is_safe:
-                    truncated = formatted[: self.config.get("max_prompt_chars", 2000)]
-                    formatted = truncated
-
-                prompt_metadatas.append((topic, difficulty, question_type))
-                formatted_prompts.append(formatted)
-
-        total = len(formatted_prompts)
-        if total == 0:
-            return []
+                try:
+                    # Cycle through difficulties and question types
+                    difficulty = difficulties[j % len(difficulties)]
+                    question_type = question_types[j % len(question_types)]
 
-        print(f"📦 Sending {total} prompts to the LLM in batch mode (if supported)")
-        start_t = time.time()
-        raw_responses = self._batch_invoke(formatted_prompts)
-        elapsed = time.time() - start_t
-        print(f"⏱ LLM batch time: {elapsed:.2f}s for {total} prompts")
+                    mcq = self.generate_mcq(topic, difficulty, question_type)
+                    mcqs.append(mcq)
 
-        # Parse raw responses back into MCQQuestion objects
-        mcqs = []
-        for meta, response in zip(prompt_metadatas, raw_responses):
-            topic, difficulty, question_type = meta
-            try:
-                response_data = self._extract_json_from_response(response)
-
-                # Reconstruct MCQQuestion
-                options = []
-                for label, text in response_data["options"].items():
-                    is_correct = label == response_data["correct_answer"]
-                    options.append(self.MCQOption(label=label, text=text, is_correct=is_correct))
-
-                mcq = self.MCQQuestion(
-                    question_id=response_data.get("id", None),
-                    topic=topic,
-                    question_text=response_data["question"],
-                    options=options,
-                    explanation=response_data.get("explanation", ""),
-                    difficulty=(difficulty if hasattr(difficulty, 'name') else difficulty),
-                    question_type=(question_type if hasattr(question_type, 'name') else question_type),
-                    confidence_score=response_data.get("confidence_score", 0.0)
-                )
-
-                if hasattr(self, 'validator'):
-                    mcq = self.validator.calculate_quality_score(mcq)
-
-                mcqs.append(mcq)
-            except Exception as e:
-                print(f"❌ Failed parsing response for topic={topic}: {e}")
+                    print(f" ✅ Generated question {j+1}/{question_per_topic} "
+                          f"(Quality: {mcq.confidence_score:.1f})")
+
+                except Exception as e:
+                    print(f" ❌ Failed to generate question {j+1}: {e}")
 
-        print(f"🎉 Generated {len(mcqs)}/{total} MCQs successfully (batched)")
+        print(f"🎉 Generated {len(mcqs)}/{total_questions} MCQs successfully")
         return mcqs
 
     def export_mcqs(self, mcqs: List[MCQQuestion], output_path: str):
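
For orientation, a minimal sketch of how the simplified generate_batch would be called after this commit. The module path, the constructor arguments, the enum member names, and the keyword names difficulties and question_types (inferred from their use inside the loop; the full signature is truncated in this diff) are assumptions, not shown here:

    # Usage sketch (assumed setup; only generate_batch and load_documents appear in this commit)
    from enhanced_rag_mcq import EnhancedRAGMCQGenerator, DifficultyLevel, QuestionType

    generator = EnhancedRAGMCQGenerator()        # constructor args not shown in this diff
    docs, _ = generator.load_documents("tmp/")   # same call pattern as fastapi_app.py uses

    mcqs = generator.generate_batch(
        topics=["neural networks", "gradient descent"],
        question_per_topic=2,
        difficulties=[DifficultyLevel.EASY],     # cycled via difficulties[j % len(difficulties)]
        question_types=None,                     # defaults to [DEFINITION, APPLICATION]
    )
    print(f"Generated {len(mcqs)} MCQs")         # each MCQQuestion carries a confidence_score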
fastapi_app.py CHANGED
@@ -68,14 +68,10 @@ async def mcq_gen(
 ):
     if not generator:
         raise HTTPException(status_code=500, detail="Generator not initialized")
-
-    print("generator initialized")
 
     topic_list = [t.strip() for t in topics.split(',') if t.strip()]
     if not topic_list:
         raise HTTPException(status_code=400, detail="At least one topic must be provided")
-
-    print("topic validated")
 
     # Validate and convert enum values
     try:
@@ -84,16 +80,12 @@ async def mcq_gen(
         valid_difficulties = [d.value for d in DifficultyLevel]
         raise HTTPException(status_code=400, detail=f"Invalid difficulty. Must be one of: {valid_difficulties}")
 
-    print("difficulty validated")
-
     try:
         qtype_enum = QuestionType(qtype.lower())
     except ValueError:
         valid_types = [q.value for q in QuestionType]
         raise HTTPException(status_code=400, detail=f"Invalid question type. Must be one of: {valid_types}")
 
-    print("question type validated")
-
     # Save uploaded PDF to temporary folder
     filename = file.filename if file.filename else "uploaded_file"
     file_path = os.path.join(tmp_folder, filename)
@@ -101,8 +93,6 @@ async def mcq_gen(
         shutil.copyfileobj(file.file, buffer)
     file.file.close()
 
-    print(f"file read and written to {file_path}")
-
     try:
         # Load and index the uploaded document
         docs, _ = generator.load_documents(tmp_folder)
@@ -110,8 +100,6 @@ async def mcq_gen(
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Document processing error: {e}")
 
-    print("loaded and indexed the document")
-
     start_time = time.time()
     try:
         mcqs = generator.generate_batch(
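
The prints removed above were one-off debug traces. If similar visibility is wanted later without polluting stdout, a standard-library logging setup along these lines could stand in (a sketch, not part of this commit):

    import logging

    logger = logging.getLogger("fastapi_app")
    logging.basicConfig(level=logging.DEBUG)

    # Stand-ins for the removed debug prints:
    logger.debug("topic validated")
    file_path = "/tmp/uploaded_file"  # placeholder; mcq_gen builds this with os.path.join(tmp_folder, filename)
    logger.info("file read and written to %s", file_path)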